Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions derive-encode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,67 @@ pub fn derive_encode_label_set(input: TokenStream) -> TokenStream {
syn::Fields::Named(syn::FieldsNamed { named, .. }) => named
.into_iter()
.map(|f| {
let attribute = f
.attrs
let ident = f.ident.unwrap();
let ident_string = KEYWORD_IDENTIFIERS
.iter()
.find(|a| a.path().is_ident("prometheus"))
.map(|a| a.parse_args::<syn::Ident>().unwrap().to_string());
let flatten = match attribute.as_deref() {
Some("flatten") => true,
Some(other) => {
panic!("Provided attribute '{other}', but only 'flatten' is supported")
.find(|pair| ident == pair.1)
.map(|pair| pair.0.to_string())
.unwrap_or_else(|| ident.to_string());

let mut flatten = false;
let mut skip_encoding_if_fn: Option<syn::Path> = None;

for attr in f.attrs.iter().filter(|a| a.path().is_ident("prometheus")) {
let result = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("flatten") {
flatten = true;
return Ok(());
}

if meta.path.is_ident("skip_encoding_if") {
let lit: syn::LitStr = meta.value()?.parse()?;
match lit.parse::<syn::Path>() {
Ok(path) => {
skip_encoding_if_fn = Some(path);
Ok(())
}
Err(err) => Err(err),
}?;
return Ok(());
}

Err(meta.error("unsupported #[prometheus(..)] attribute"))
});

if let Err(err) = result {
return err.to_compile_error();
}
None => false,
};
let ident = f.ident.unwrap();
}

if flatten {
quote! {
EncodeLabelSet::encode(&self.#ident, encoder)?;
EncodeLabelSet::encode(&self.#ident, encoder)?;
}
} else if let Some(skip_fn) = skip_encoding_if_fn {
quote! {
if !(#skip_fn(&self.#ident)) {
let mut label_encoder = encoder.encode_label();
let mut label_key_encoder = label_encoder.encode_label_key()?;
EncodeLabelKey::encode(&#ident_string, &mut label_key_encoder)?;

let mut label_value_encoder = label_key_encoder.encode_label_value()?;
EncodeLabelValue::encode(&self.#ident, &mut label_value_encoder)?;
label_value_encoder.finish()?;
}
}
} else {
let ident_string = KEYWORD_IDENTIFIERS
.iter()
.find(|pair| ident == pair.1)
.map(|pair| pair.0.to_string())
.unwrap_or_else(|| ident.to_string());

quote! {
let mut label_encoder = encoder.encode_label();
let mut label_key_encoder = label_encoder.encode_label_key()?;
EncodeLabelKey::encode(&#ident_string, &mut label_key_encoder)?;

let mut label_value_encoder = label_key_encoder.encode_label_value()?;
EncodeLabelValue::encode(&self.#ident, &mut label_value_encoder)?;

label_value_encoder.finish()?;
}
}
Expand All @@ -64,15 +93,18 @@ pub fn derive_encode_label_set(input: TokenStream) -> TokenStream {
}
syn::Fields::Unit => panic!("Can not derive Encode for struct with unit field."),
},
syn::Data::Enum(syn::DataEnum { .. }) => {
syn::Data::Enum(_) => {
panic!("Can not derive Encode for enum.")
}
syn::Data::Union(_) => panic!("Can not derive Encode for union."),
};

let gen = quote! {
impl ::prometheus_client::encoding::EncodeLabelSet for #name {
fn encode(&self, encoder: &mut ::prometheus_client::encoding::LabelSetEncoder) -> ::core::result::Result<(), ::core::fmt::Error> {
fn encode(
&self,
encoder: &mut prometheus_client::encoding::LabelSetEncoder,
) -> std::result::Result<(), std::fmt::Error> {
use ::prometheus_client::encoding::EncodeLabel;
use ::prometheus_client::encoding::EncodeLabelKey;
use ::prometheus_client::encoding::EncodeLabelValue;
Expand Down
55 changes: 55 additions & 0 deletions derive-encode/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,61 @@ fn flatten() {
assert_eq!(expected, buffer);
}

#[test]
fn skip_encoding_if() {
fn skip_empty_string(s: &String) -> bool {
s.is_empty()
}

fn skip_zero(n: &u64) -> bool {
*n == 0
}

#[derive(EncodeLabelSet, Hash, Clone, Eq, PartialEq, Debug)]
struct Labels {
method: String,
#[prometheus(skip_encoding_if = "skip_empty_string")]
path: String,
#[prometheus(skip_encoding_if = "skip_zero")]
status_code: u64,
user_id: u64,
}

let mut registry = Registry::default();
let family = Family::<Labels, Counter>::default();
registry.register("my_counter", "This is my counter", family.clone());

family
.get_or_create(&Labels {
method: "GET".to_string(),
path: "".to_string(), // This should be skipped
status_code: 0, // This should be skipped
user_id: 123,
})
.inc();

family
.get_or_create(&Labels {
method: "POST".to_string(),
path: "/api/users".to_string(), // This should not be skipped
status_code: 200, // This should not be skipped
user_id: 456,
})
.inc();

let mut buffer = String::new();
encode(&mut buffer, &registry).unwrap();

assert!(buffer.contains("# HELP my_counter This is my counter."));
assert!(buffer.contains("# TYPE my_counter counter"));
assert!(buffer.contains("my_counter_total{method=\"GET\",user_id=\"123\"} 1"));
assert!(buffer.contains("my_counter_total{method=\"POST\",path=\"/api/users\",status_code=\"200\",user_id=\"456\"} 1"));
assert!(buffer.contains("# EOF"));

assert!(!buffer.contains("path=\"\""));
assert!(!buffer.contains("status_code=\"0\""));
}

#[test]
fn build() {
let t = trybuild::TestCases::new();
Expand Down