fix: allowing deriving Patch for a struct with generic argument (closes #4163) (#4175)

This commit is contained in:
mskorkowski
2025-08-03 14:28:42 +02:00
committed by GitHub
parent a5e0053bab
commit e2e28ef180

View File

@@ -6,8 +6,8 @@ use syn::{
parse::{Parse, ParseStream, Parser},
punctuated::Punctuated,
token::Comma,
ExprClosure, Field, Fields, Generics, Ident, Index, Meta, Result, Token,
Type, Variant, Visibility, WhereClause,
ExprClosure, Field, Fields, GenericParam, Generics, Ident, Index, Meta,
Result, Token, Type, TypeParam, Variant, Visibility, WhereClause,
};
#[proc_macro_error]
@@ -26,6 +26,103 @@ pub fn derive_patch(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.into()
}
/// Removes all constraints from generics arguments list.
///
/// # Example
///
/// ```rust,ignore
/// struct Data<
/// 'a,
/// T1: ToString + PatchField,
/// T2: PatchField,
/// T3: 'static + PatchField,
/// T4,
/// >
/// where
/// T3: ToString,
/// T4: ToString + PatchField,
/// {
/// data1: &'a T1,
/// data2: T2,
/// data3: T3,
/// data4: T4,
/// }
/// ```
///
/// Fort the struct above the `[syn::DeriveInput::parse]` will return the instance of [syn::Generics]
/// which will conceptually look like this
///
/// ```text
/// Generics:
/// params:
/// [
/// 'a,
/// T1: ToString + PatchField,
/// T2: PatchField,
/// T3: 'static + PatchField,
/// T4,
/// ]
/// where_clause:
/// [
/// T3: ToString,
/// T4: ToString + PatchField,
/// ]
/// ```
///
/// This method would return a new instance of [syn::Generics] which will conceptually look like this
///
/// ```text
/// Generics:
/// params:
/// [
/// 'a,
/// T1,
/// T2,
/// T3,
/// T4,
/// ]
/// where_clause:
/// []
/// ```
///
/// This is useful when you want to use a generic arguments list for `impl` sections for type definitions.
fn remove_constraint_from_generics(generics: &Generics) -> Generics {
let mut new_generics = generics.clone();
// remove contraints directly placed in the generic arguments list
//
// For generics for `struct A<T: MyTrait>` the `T: MyTrait` becomes `T`
for param in new_generics.params.iter_mut() {
match param {
GenericParam::Lifetime(lifetime) => {
lifetime.bounds.clear(); // remove bounds
lifetime.colon_token = None;
}
GenericParam::Type(type_param) => {
type_param.bounds.clear(); // remove bounds
type_param.colon_token = None;
type_param.eq_token = None;
type_param.default = None;
}
GenericParam::Const(const_param) => {
// replaces const generic with type param without bounds which is basically an `ident` token
*param = GenericParam::Type(TypeParam {
attrs: const_param.attrs.clone(),
ident: const_param.ident.clone(),
colon_token: None,
bounds: Punctuated::new(),
eq_token: None,
default: None,
});
}
}
}
new_generics.where_clause = None; // remove where clause
new_generics
}
struct Model {
vis: Visibility,
name: Ident,
@@ -111,7 +208,9 @@ impl ToTokens for Model {
} = &self;
let any_store_field = Ident::new("AnyStoreField", Span::call_site());
let trait_name = Ident::new(&format!("{name}StoreFields"), name.span());
let clear_generics = remove_constraint_from_generics(generics);
let params = &generics.params;
let clear_params = &clear_generics.params;
let generics_with_orig = quote! { <#any_store_field, #params> };
let where_with_orig = {
generics
@@ -124,17 +223,22 @@ impl ToTokens for Model {
} = &w;
quote! {
#where_token
#any_store_field: #library_path::StoreField<Value = #name #generics>,
#any_store_field: #library_path::StoreField<Value = #name < #clear_params > >,
#predicates
}
})
.unwrap_or_else(|| quote! { where #any_store_field: #library_path::StoreField<Value = #name #generics> })
.unwrap_or_else(|| quote! { where #any_store_field: #library_path::StoreField<Value = #name < #clear_params > > })
};
// define an extension trait that matches this struct
// and implement that trait for all StoreFields
let (trait_fields, read_fields): (Vec<_>, Vec<_>) =
ty.to_field_data(&library_path, generics, &any_store_field, name);
let (trait_fields, read_fields): (Vec<_>, Vec<_>) = ty.to_field_data(
&library_path,
generics,
&clear_generics,
&any_store_field,
name,
);
// read access
tokens.extend(quote! {
@@ -144,7 +248,7 @@ impl ToTokens for Model {
#(#trait_fields)*
}
impl #generics_with_orig #trait_name <AnyStoreField, #params> for AnyStoreField
impl #generics_with_orig #trait_name <AnyStoreField, #clear_params> for AnyStoreField
#where_with_orig
{
#(#read_fields)*
@@ -158,6 +262,7 @@ impl ModelTy {
&self,
library_path: &TokenStream,
generics: &Generics,
clear_generics: &Generics,
any_store_field: &Ident,
name: &Ident,
) -> (Vec<TokenStream>, Vec<TokenStream>) {
@@ -204,6 +309,7 @@ impl ModelTy {
library_path,
ident.as_ref(),
generics,
clear_generics,
any_store_field,
name,
ty,
@@ -215,6 +321,7 @@ impl ModelTy {
library_path,
ident.as_ref(),
generics,
clear_generics,
any_store_field,
name,
ty,
@@ -233,6 +340,7 @@ impl ModelTy {
library_path,
ident,
generics,
clear_generics,
any_store_field,
name,
fields,
@@ -242,6 +350,7 @@ impl ModelTy {
library_path,
ident,
generics,
clear_generics,
any_store_field,
name,
fields,
@@ -260,7 +369,8 @@ fn field_to_tokens(
modes: Option<&[SubfieldMode]>,
library_path: &proc_macro2::TokenStream,
orig_ident: Option<&Ident>,
generics: &Generics,
_generics: &Generics,
clear_generics: &Generics,
any_store_field: &Ident,
name: &Ident,
ty: &Type,
@@ -285,7 +395,7 @@ fn field_to_tokens(
SubfieldMode::Keyed(keyed_by, key_ty) => {
let signature = quote! {
#[track_caller]
fn #ident(self) -> #library_path::KeyedSubfield<#any_store_field, #name #generics, #key_ty, #ty>
fn #ident(self) -> #library_path::KeyedSubfield<#any_store_field, #name #clear_generics, #key_ty, #ty>
};
return if include_body {
quote! {
@@ -318,7 +428,7 @@ fn field_to_tokens(
// default subfield
if include_body {
quote! {
fn #ident(self) -> #library_path::Subfield<#any_store_field, #name #generics, #ty> {
fn #ident(self) -> #library_path::Subfield<#any_store_field, #name #clear_generics, #ty> {
#library_path::Subfield::new(
self,
#idx.into(),
@@ -329,7 +439,7 @@ fn field_to_tokens(
}
} else {
quote! {
fn #ident(self) -> #library_path::Subfield<#any_store_field, #name #generics, #ty>;
fn #ident(self) -> #library_path::Subfield<#any_store_field, #name #clear_generics, #ty>;
}
}
}
@@ -339,7 +449,8 @@ fn variant_to_tokens(
include_body: bool,
library_path: &proc_macro2::TokenStream,
ident: &Ident,
generics: &Generics,
_generics: &Generics,
clear_generics: &Generics,
any_store_field: &Ident,
name: &Ident,
fields: &Fields,
@@ -408,7 +519,7 @@ fn variant_to_tokens(
// default subfield
if include_body {
quote! {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #generics, #field_ty>> {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #clear_generics, #field_ty>> {
#library_path::StoreField::track_field(&self);
let reader = #library_path::StoreField::reader(&self);
let matches = reader
@@ -440,7 +551,7 @@ fn variant_to_tokens(
}
} else {
quote! {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #generics, #field_ty>>;
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #clear_generics, #field_ty>>;
}
}
}));
@@ -491,7 +602,7 @@ fn variant_to_tokens(
// default subfield
if include_body {
quote! {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #generics, #field_ty>> {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #clear_generics, #field_ty>> {
#library_path::StoreField::track_field(&self);
let reader = #library_path::StoreField::reader(&self);
let matches = reader
@@ -523,7 +634,7 @@ fn variant_to_tokens(
}
} else {
quote! {
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #generics, #field_ty>>;
fn #combined_ident(self) -> Option<#library_path::Subfield<#any_store_field, #name #clear_generics, #field_ty>>;
}
}
}));
@@ -665,9 +776,14 @@ impl ToTokens for PatchModel {
}
};
let clear_generics = remove_constraint_from_generics(generics);
let params = clear_generics.params;
let where_clause = &generics.where_clause;
// read access
tokens.extend(quote! {
impl #library_path::PatchField for #name #generics
impl #generics #library_path::PatchField for #name <#params>
#where_clause
{
fn patch_field(
&mut self,