diff --git a/src/parse/generics.rs b/src/parse/generics.rs index f393eaf..1d38093 100644 --- a/src/parse/generics.rs +++ b/src/parse/generics.rs @@ -360,7 +360,9 @@ fn test_generics_try_take() { dbg!(&generics); assert_eq!(generics.len(), 2); assert_eq!(generics[0].ident(), "A"); + assert_eq!(generics[0].constraints().len(), 5); assert_eq!(generics[1].ident(), "B"); + assert_eq!(generics[1].constraints().len(), 0); let stream = &mut token_stream("struct Bar"); let (data_type, ident) = super::DataType::take(stream).unwrap(); @@ -376,6 +378,40 @@ fn test_generics_try_take() { } else { panic!("Expected simple generic, got {:?}", generics[0]); } + + let stream = &mut token_stream("struct Bar"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Bar"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + dbg!(&generics); + assert_eq!(generics.len(), 1); + if let Generic::Generic(generic) = &generics[0] { + assert_eq!(generic.ident, "A"); + assert_eq!(generic.constraints.len(), 1); + assert_eq!(generic.constraints[0].to_string(), "Debug"); + assert_eq!(generic.default_value.len(), 1); + assert_eq!(generic.default_value[0].to_string(), "()"); + } else { + panic!("Expected simple generic, got {:?}", generics[0]); + } + + let stream = &mut token_stream("struct Bar"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Bar"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + dbg!(&generics); + assert_eq!(generics.len(), 1); + if let Generic::Const(generic) = &generics[0] { + assert_eq!(generic.ident, "S"); + assert_eq!(generic.constraints.len(), 1); + assert_eq!(generic.constraints[0].to_string(), "usize"); + assert_eq!(generic.default_value.len(), 1); + assert_eq!(generic.default_value[0].to_string(), "8"); + } else { + panic!("Expected simple generic, got {:?}", generics[0]); + } } /// a lifetime generic parameter, e.g. `struct Foo<'a> { ... }` @@ -445,22 +481,31 @@ pub struct SimpleGeneric { pub default_value: Vec, } +fn try_as_punct_char(tt: &TokenTree) -> Option { + if let TokenTree::Punct(punct) = tt { + Some(punct.as_char()) + } else { + None + } +} + impl SimpleGeneric { pub(crate) fn take(input: &mut Peekable>) -> Result { let ident = assume_ident(input.next()); let mut constraints = Vec::new(); let mut default_value = Vec::new(); - if let Some(TokenTree::Punct(punct)) = input.peek() { - let punct_char = punct.as_char(); - if punct_char == ':' { - assume_punct(input.next(), ':'); - constraints = read_tokens_until_punct(input, &['>', ','])?; - } - if punct_char == '=' { - assume_punct(input.next(), '='); - default_value = read_tokens_until_punct(input, &['>', ','])?; - } + + let mut punct_char = input.peek().and_then(try_as_punct_char); + if punct_char == Some(':') { + assume_punct(input.next(), ':'); + constraints = read_tokens_until_punct(input, &['>', ',', '='])?; + punct_char = input.peek().and_then(try_as_punct_char); } + if punct_char == Some('=') { + assume_punct(input.next(), '='); + default_value = read_tokens_until_punct(input, &['>', ','])?; + } + Ok(Self { ident, constraints, @@ -483,6 +528,8 @@ pub struct ConstGeneric { pub ident: Ident, /// The "constraints" (type) of this generic, e.g. the `usize` from `const N: usize` pub constraints: Vec, + /// The default value of this generic, e.g. `const S = 8` + pub default_value: Vec, } impl ConstGeneric { @@ -490,16 +537,24 @@ impl ConstGeneric { let const_token = assume_ident(input.next()); let ident = assume_ident(input.next()); let mut constraints = Vec::new(); - if let Some(TokenTree::Punct(punct)) = input.peek() { - if punct.as_char() == ':' { - assume_punct(input.next(), ':'); - constraints = read_tokens_until_punct(input, &['>', ','])?; - } + let mut default_value = Vec::new(); + + let mut punct_char = input.peek().and_then(try_as_punct_char); + if punct_char == Some(':') { + assume_punct(input.next(), ':'); + constraints = read_tokens_until_punct(input, &['>', ',', '='])?; + punct_char = input.peek().and_then(try_as_punct_char); + } + if punct_char == Some('=') { + assume_punct(input.next(), '='); + default_value = read_tokens_until_punct(input, &['>', ','])?; } + Ok(Self { const_token, ident, constraints, + default_value, }) } } diff --git a/test/src/main.rs b/test/src/main.rs index a96022c..d25c90b 100644 --- a/test/src/main.rs +++ b/test/src/main.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + bitflags::bitflags! { #[derive(virtue_test_derive::RetHi)] pub struct Foo: u8 { @@ -7,7 +9,7 @@ bitflags::bitflags! { } #[derive(virtue_test_derive::RetHi)] -pub struct DefaultGeneric { +pub struct DefaultGeneric { pub t: T, }