![gal@spitfire.co.il](/assets/img/avatar_default.png)
![GitHub](/assets/img/avatar_default.png)
6 changed files with 161 additions and 198 deletions
@ -1,11 +1,40 @@
@@ -1,11 +1,40 @@
|
||||
#[cfg(unix)] |
||||
pub mod unix; |
||||
|
||||
#[cfg(windows)] |
||||
pub mod windows; |
||||
use super::{Bash, Fish, PowerShell, Shell, WindowsCmd, Zsh}; |
||||
use log::debug; |
||||
use std::ffi::OsStr; |
||||
use sysinfo::{ProcessExt, System, SystemExt}; |
||||
|
||||
#[derive(Debug)] |
||||
struct ProcessInfo { |
||||
parent_pid: Option<u32>, |
||||
command: String, |
||||
} |
||||
|
||||
pub fn infer_shell() -> Option<Box<dyn Shell>> { |
||||
let system = System::new_all(); |
||||
let hashmap = system.processes(); |
||||
let mut current_pid = sysinfo::get_current_pid().ok(); |
||||
|
||||
while let Some(pid) = current_pid { |
||||
if let Some(process) = hashmap.get(&pid) { |
||||
current_pid = process.parent(); |
||||
let process_name = process |
||||
.exe() |
||||
.file_stem() |
||||
.and_then(OsStr::to_str) |
||||
.map(str::to_lowercase); |
||||
let sliced = process_name.as_ref().map(|x| &x[..]); |
||||
match sliced { |
||||
Some("sh" | "bash") => return Some(Box::from(Bash)), |
||||
Some("zsh") => return Some(Box::from(Zsh)), |
||||
Some("fish") => return Some(Box::from(Fish)), |
||||
Some("pwsh" | "powershell") => return Some(Box::from(PowerShell)), |
||||
Some("cmd") => return Some(Box::from(WindowsCmd)), |
||||
cmd_name => debug!("binary is not a supported shell: {:?}", cmd_name), |
||||
}; |
||||
} else { |
||||
current_pid = None; |
||||
} |
||||
} |
||||
|
||||
None |
||||
} |
||||
|
@ -1,99 +0,0 @@
@@ -1,99 +0,0 @@
|
||||
#![cfg(unix)] |
||||
|
||||
use super::super::{Bash, Fish, PowerShell, Shell, Zsh}; |
||||
use log::debug; |
||||
use std::io::{Error, ErrorKind}; |
||||
|
||||
#[derive(Debug)] |
||||
struct ProcessInfo { |
||||
parent_pid: Option<u32>, |
||||
command: String, |
||||
} |
||||
|
||||
const MAX_ITERATIONS: u8 = 10; |
||||
|
||||
pub fn infer_shell() -> Option<Box<dyn Shell>> { |
||||
let mut pid = Some(std::process::id()); |
||||
let mut visited = 0; |
||||
|
||||
while pid != None && visited < MAX_ITERATIONS { |
||||
let process_info = get_process_info(pid.unwrap()).ok()?; |
||||
let binary = process_info |
||||
.command |
||||
.trim_start_matches('-') |
||||
.split('/') |
||||
.last() |
||||
.expect("Can't read file name of process tree"); |
||||
|
||||
match binary { |
||||
"sh" | "bash" => return Some(Box::from(Bash)), |
||||
"zsh" => return Some(Box::from(Zsh)), |
||||
"fish" => return Some(Box::from(Fish)), |
||||
"pwsh" => return Some(Box::from(PowerShell)), |
||||
cmd_name => debug!("binary is not a supported shell: {:?}", cmd_name), |
||||
}; |
||||
|
||||
pid = process_info.parent_pid; |
||||
visited += 1; |
||||
} |
||||
|
||||
None |
||||
} |
||||
|
||||
fn get_process_info(pid: u32) -> std::io::Result<ProcessInfo> { |
||||
use std::io::{BufRead, BufReader}; |
||||
use std::process::Command; |
||||
|
||||
let buffer = Command::new("ps") |
||||
.arg("-o") |
||||
.arg("ppid,comm") |
||||
.arg(pid.to_string()) |
||||
.stdout(std::process::Stdio::piped()) |
||||
.spawn()? |
||||
.stdout |
||||
.ok_or_else(|| Error::from(ErrorKind::UnexpectedEof))?; |
||||
|
||||
let mut lines = BufReader::new(buffer).lines(); |
||||
|
||||
// skip header line
|
||||
lines |
||||
.next() |
||||
.ok_or_else(|| Error::from(ErrorKind::UnexpectedEof))??; |
||||
|
||||
let line = lines |
||||
.next() |
||||
.ok_or_else(|| Error::from(ErrorKind::NotFound))??; |
||||
|
||||
let mut parts = line.trim().split_whitespace(); |
||||
let ppid = parts |
||||
.next() |
||||
.expect("Can't read the ppid from ps, should be the first item in the table"); |
||||
let command = parts |
||||
.next() |
||||
.expect("Can't read the command from ps, should be the second item in the table"); |
||||
|
||||
Ok(ProcessInfo { |
||||
parent_pid: ppid.parse().ok(), |
||||
command: command.into(), |
||||
}) |
||||
} |
||||
|
||||
#[cfg(all(test, unix))] |
||||
mod tests { |
||||
use super::*; |
||||
use pretty_assertions::assert_eq; |
||||
use std::process::{Command, Stdio}; |
||||
|
||||
#[test] |
||||
fn test_get_process_info() { |
||||
let subprocess = Command::new("bash") |
||||
.stdin(Stdio::piped()) |
||||
.stdout(Stdio::piped()) |
||||
.stderr(Stdio::piped()) |
||||
.spawn() |
||||
.expect("Can't execute command"); |
||||
let process_info = get_process_info(subprocess.id()); |
||||
let parent_pid = process_info.ok().and_then(|x| x.parent_pid); |
||||
assert_eq!(parent_pid, Some(std::process::id())); |
||||
} |
||||
} |
@ -1,84 +0,0 @@
@@ -1,84 +0,0 @@
|
||||
#![cfg(windows)] |
||||
|
||||
use super::super::{Bash, PowerShell, Shell, WindowsCmd}; |
||||
use serde::Deserialize; |
||||
use std::collections::HashMap; |
||||
|
||||
#[derive(Deserialize, Debug)] |
||||
pub struct ProcessInfo { |
||||
#[serde(rename = "ExecutablePath")] |
||||
executable_path: Option<std::path::PathBuf>, |
||||
#[serde(rename = "ParentProcessId")] |
||||
parent_pid: u32, |
||||
#[serde(rename = "ProcessId")] |
||||
pid: u32, |
||||
} |
||||
|
||||
pub fn infer_shell() -> Option<Box<dyn Shell>> { |
||||
let process_map = get_process_map().ok()?; |
||||
let process_tree = get_process_tree(process_map, std::process::id()); |
||||
|
||||
for process in process_tree { |
||||
if let Some(exec_path) = process.executable_path { |
||||
match exec_path.file_name().and_then(|x| x.to_str()) { |
||||
Some("cmd.exe") => { |
||||
return Some(Box::from(WindowsCmd)); |
||||
} |
||||
Some("bash.exe") => { |
||||
return Some(Box::from(Bash)); |
||||
} |
||||
Some("powershell.exe") | Some("pwsh.exe") => { |
||||
return Some(Box::from(PowerShell)); |
||||
} |
||||
_ => {} |
||||
} |
||||
} |
||||
} |
||||
|
||||
None |
||||
} |
||||
|
||||
type ProcessMap = HashMap<u32, ProcessInfo>; |
||||
|
||||
pub fn get_process_tree(mut process_map: ProcessMap, pid: u32) -> Vec<ProcessInfo> { |
||||
let mut vec = vec![]; |
||||
let mut current = process_map.remove(&pid); |
||||
|
||||
while let Some(process) = current { |
||||
current = process_map.remove(&process.parent_pid); |
||||
vec.push(process); |
||||
} |
||||
|
||||
vec |
||||
} |
||||
|
||||
pub fn get_process_map() -> std::io::Result<ProcessMap> { |
||||
let stdout = std::process::Command::new("wmic") |
||||
.args(&[ |
||||
"process", |
||||
"get", |
||||
"processid,parentprocessid,executablepath", |
||||
"/format:csv", |
||||
]) |
||||
.stdout(std::process::Stdio::piped()) |
||||
.spawn()? |
||||
.stdout |
||||
.ok_or(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; |
||||
|
||||
let mut reader = csv::Reader::from_reader(stdout); |
||||
let hashmap: HashMap<_, _> = reader |
||||
.deserialize::<ProcessInfo>() |
||||
.filter_map(Result::ok) |
||||
.map(|x| (x.pid, x)) |
||||
.collect(); |
||||
Ok(hashmap) |
||||
} |
||||
|
||||
#[cfg(test)] |
||||
mod tests { |
||||
#[test] |
||||
fn test_me() { |
||||
let processes = super::get_process_map().unwrap(); |
||||
assert!(processes.contains_key(&std::process::id())); |
||||
} |
||||
} |
Loading…
Reference in new issue