Skip to content
60 changes: 60 additions & 0 deletions uefi-test-runner/src/proto/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,65 @@ pub fn test_current_dir(shell: &ScopedProtocol<Shell>) {
assert_eq!(cur_fs_str, expected_fs_str);
}

/// Test `var()`, `vars()`, and `set_var()`
pub fn test_var(shell: &ScopedProtocol<Shell>) {
/* Test retrieving list of environment variable names */
let mut cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("path"));
// check pre-defined shell variables; see UEFI Shell spec
assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("nonesting"));
let cur_env_vec = shell.vars();
let default_len = cur_env_vec.count();

/* Test setting and getting a specific environment variable */
let cur_env_vec = shell.vars();
let test_var = cstr16!("test_var");
let test_val = cstr16!("test_val");
assert!(shell.var(test_var).is_none());
let status = shell.set_var(test_var, test_val, false);
assert!(status.is_ok());
let cur_env_str = shell
.var(test_var)
.expect("Could not get environment variable");
assert_eq!(cur_env_str, test_val);

let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(!found_var);
let cur_env_vec = shell.vars();
let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(found_var);

let cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.count(), default_len + 1);

/* Test deleting environment variable */
let test_val = cstr16!("");
let status = shell.set_var(test_var, test_val, false);
assert!(status.is_ok());
assert!(shell.var(test_var).is_none());

let cur_env_vec = shell.vars();
let mut found_var = false;
for (env_var, _) in cur_env_vec {
if env_var == test_var {
found_var = true;
}
}
assert!(!found_var);
let cur_env_vec = shell.vars();
assert_eq!(cur_env_vec.count(), default_len);
}

pub fn test() {
info!("Running shell protocol tests");

Expand All @@ -109,4 +168,5 @@ pub fn test() {
boot::open_protocol_exclusive::<Shell>(handle).expect("Failed to open Shell protocol");

test_current_dir(&shell);
test_var(&shell);
}
1 change: 1 addition & 0 deletions uefi/src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod rng;
#[cfg(feature = "alloc")]
pub mod scsi;
pub mod security;
#[cfg(feature = "alloc")]
pub mod shell;
pub mod shell_params;
pub mod shim;
Expand Down
192 changes: 192 additions & 0 deletions uefi/src/proto/shell/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

use crate::proto::unsafe_protocol;
use crate::{CStr16, Char16, Error, Result, Status, StatusExt};

use core::marker::PhantomData;
use core::ptr;
use uefi_raw::protocol::shell::ShellProtocol;

Expand All @@ -13,6 +15,45 @@ use uefi_raw::protocol::shell::ShellProtocol;
#[unsafe_protocol(ShellProtocol::GUID)]
pub struct Shell(ShellProtocol);

/// Trait for implementing the var function
pub trait ShellVar {
/// Gets the value of the specified environment variable
fn var(&self, name: &CStr16) -> Option<&CStr16>;
}

/// Iterator over the names of environmental variables obtained from the Shell protocol.
#[derive(Debug)]
pub struct Vars<'a, T: ShellVar> {
/// Char16 containing names of environment variables
names: *const Char16,
/// Reference to Shell Protocol
protocol: *const T,
/// Placeholder to attach a lifetime to `Vars`
placeholder: PhantomData<&'a CStr16>,
}

impl<'a, T: ShellVar + 'a> Iterator for Vars<'a, T> {
type Item = (&'a CStr16, Option<&'a CStr16>);
// We iterate a list of NUL terminated CStr16s.
// The list is terminated with a double NUL.
fn next(&mut self) -> Option<Self::Item> {
let s = unsafe { CStr16::from_ptr(self.names) };
if s.is_empty() {
None
} else {
self.names = unsafe { self.names.add(s.num_chars() + 1) };
Some((s, unsafe { self.protocol.as_ref().unwrap().var(s) }))
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the implementation could be simplified, something like this:

 fn next(&mut self) -> Option<Self::Item> { let s = unsafe { CStr16::from_ptr(self.inner) }; if s.is_empty() { None } else { self.inner = unsafe { self.inner.add(s.num_chars() + 1) }; Some(s) } }

impl ShellVar for Shell {
/// Gets the value of the specified environment variable
fn var(&self, name: &CStr16) -> Option<&CStr16> {
self.var(name)
}
}

impl Shell {
/// Returns the current directory on the specified device.
///
Expand Down Expand Up @@ -54,4 +95,155 @@ impl Shell {
let dir_ptr: *const Char16 = directory.map_or(ptr::null(), |x| x.as_ptr());
unsafe { (self.0.set_cur_dir)(fs_ptr.cast(), dir_ptr.cast()) }.to_result()
}

/// Gets the value of the specified environment variable
///
/// # Arguments
///
/// * `name` - The environment variable name of which to retrieve the
/// value.
///
/// # Returns
///
/// * `Some(<env_value>)` - &CStr16 containing the value of the
/// environment variable
/// * `None` - If environment variable does not exist
#[must_use]
pub fn var(&self, name: &CStr16) -> Option<&CStr16> {
let name_ptr: *const Char16 = name.as_ptr();
let var_val = unsafe { (self.0.get_env)(name_ptr.cast()) };
if var_val.is_null() {
None
} else {
unsafe { Some(CStr16::from_ptr(var_val.cast())) }
}
}

/// Gets an iterator over the names of all environment variables
#[must_use]
pub fn vars(&self) -> Vars<'_, Self> {
let env_ptr = unsafe { (self.0.get_env)(ptr::null()) };
Vars {
names: env_ptr.cast::<Char16>(),
protocol: self,
placeholder: PhantomData,
}
}

/// Sets the environment variable
///
/// # Arguments
///
/// * `name` - The environment variable for which to set the value
/// * `value` - The new value of the environment variable
/// * `volatile` - Indicates whether the variable is volatile or
/// not
///
/// # Returns
///
/// * `Status::SUCCESS` - The variable was successfully set
pub fn set_var(&self, name: &CStr16, value: &CStr16, volatile: bool) -> Result {
let name_ptr: *const Char16 = name.as_ptr();
let value_ptr: *const Char16 = value.as_ptr();
unsafe { (self.0.set_env)(name_ptr.cast(), value_ptr.cast(), volatile) }.to_result()
}
}

#[cfg(test)]
mod tests {
use super::*;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use uefi::cstr16;

struct ShellMock<'a> {
inner: BTreeMap<&'a CStr16, &'a CStr16>,
}

impl<'a> ShellMock<'a> {
fn new(names: Vec<&'a CStr16>, values: Vec<&'a CStr16>) -> ShellMock<'a> {
let mut inner_map = BTreeMap::new();
for (name, val) in names.iter().zip(values.iter()) {
inner_map.insert(*name, *val);
}
ShellMock { inner: inner_map }
}
}
impl<'a> ShellVar for ShellMock<'a> {
fn var(&self, name: &CStr16) -> Option<&CStr16> {
if let Some(val) = self.inner.get(name) {
Some(*val)
} else {
None
}
}
}

/// Testing Vars struct
#[test]
fn test_vars() {
// Empty Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(0);
vars_mock.push(0);
let mut vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(Vec::new(), Vec::new()),
placeholder: PhantomData,
};

assert!(vars.next().is_none());

// One environment variable in Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(b'f' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(0);
vars_mock.push(0);
let vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(Vec::from([cstr16!("foo")]), Vec::from([cstr16!("value")])),
placeholder: PhantomData,
};
assert_eq!(
vars.collect::<Vec<_>>(),
Vec::from([(cstr16!("foo"), Some(cstr16!("value")))])
);

// Multiple environment variables in Vars
let mut vars_mock = Vec::<u16>::new();
vars_mock.push(b'f' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'o' as u16);
vars_mock.push(b'1' as u16);
vars_mock.push(0);
vars_mock.push(b'b' as u16);
vars_mock.push(b'a' as u16);
vars_mock.push(b'r' as u16);
vars_mock.push(0);
vars_mock.push(b'b' as u16);
vars_mock.push(b'a' as u16);
vars_mock.push(b'z' as u16);
vars_mock.push(b'2' as u16);
vars_mock.push(0);
vars_mock.push(0);

let vars = Vars {
names: vars_mock.as_ptr().cast(),
protocol: &ShellMock::new(
Vec::from([cstr16!("foo1"), cstr16!("bar"), cstr16!("baz2")]),
Vec::from([cstr16!("value"), cstr16!("one"), cstr16!("two")]),
),
placeholder: PhantomData,
};
assert_eq!(
vars.collect::<Vec<_>>(),
Vec::from([
(cstr16!("foo1"), Some(cstr16!("value"))),
(cstr16!("bar"), Some(cstr16!("one"))),
(cstr16!("baz2"), Some(cstr16!("two")))
])
);
}
}
Loading