-
Notifications
You must be signed in to change notification settings - Fork 2k
Expand file tree
/
Copy pathlib.rs
More file actions
122 lines (117 loc) · 3.79 KB
/
lib.rs
File metadata and controls
122 lines (117 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Ident, Type};
fn get_type_tip(t: &Type) -> Option<&Ident> {
let syn::Type::Path(path) = t else {
return None;
};
let segment = path.path.segments.last()?;
Some(&segment.ident)
}
/// Allow all fields in the extractor config to be also overrideable by extractor CLI flags
#[proc_macro_attribute]
pub fn extractor_cli_config(_attr: TokenStream, item: TokenStream) -> TokenStream {
let ast = syn::parse_macro_input!(item as syn::ItemStruct);
let name = &ast.ident;
let fields = ast
.fields
.iter()
.map(|f| {
let ty_tip = get_type_tip(&f.ty);
if f.ident.as_ref().is_some_and(|i| i != "inputs")
&& ty_tip.is_some_and(|i| i == "Vec")
{
quote! {
#[serde(deserialize_with="deserialize::deserialize_newline_or_comma_separated_vec")]
#f
}
} else if ty_tip.is_some_and(|i| i == "FxHashMap" || i == "HashMap") {
quote! {
#[serde(deserialize_with="deserialize::deserialize_newline_or_comma_separated_map")]
#f
}
} else {
quote! { #f }
}
})
.collect::<Vec<_>>();
let cli_name = format_ident!("Cli{}", name);
let cli_fields = ast
.fields
.iter()
.map(|f| {
let id = f.ident.as_ref().unwrap();
let ty = &f.ty;
let type_tip = get_type_tip(ty);
if type_tip.is_some_and(|i| i == "bool") {
quote! {
#[arg(long)]
#[serde(skip_serializing_if="<&bool>::not")]
#id: bool
}
} else if type_tip.is_some_and(|i| i == "Option") {
quote! {
#[arg(long)]
#f
}
} else if id == &format_ident!("verbose") {
quote! {
#[arg(long, short, action=clap::ArgAction::Count)]
#[serde(skip_serializing_if="u8::is_zero")]
#id: u8
}
} else if id == &format_ident!("inputs") {
quote! {
#f
}
} else if type_tip.is_some_and(|i| i == "Vec" || i == "FxHashMap" || i == "HashMap") {
quote! {
#[arg(long)]
#id: Option<String>
}
} else {
quote! {
#[arg(long)]
#id: Option<#ty>
}
}
})
.collect::<Vec<_>>();
let debug_fields = ast
.fields
.iter()
.map(|f| {
let id = f.ident.as_ref().unwrap();
if id == &format_ident!("inputs") {
quote! {
.field("number of inputs", &self.#id.len())
}
} else {
quote! {
.field(stringify!(#id), &self.#id)
}
}
})
.collect::<Vec<_>>();
let ret = quote! {
#[serde_with::apply(_ => #[serde(default)])]
#[derive(Deserialize, Default)]
pub struct #name {
#(#fields),*
}
impl Debug for #name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("configuration:")
#(#debug_fields)*
.finish()
}
}
#[serde_with::skip_serializing_none]
#[derive(clap::Parser, Serialize)]
#[command(about, long_about = None)]
struct #cli_name {
#(#cli_fields),*
}
};
ret.into()
}