diff --git a/include/net/netmem.h b/include/net/netmem.h index 2113a197abb3..a96b3e5e5574 100644 --- a/include/net/netmem.h +++ b/include/net/netmem.h @@ -401,8 +401,24 @@ static inline bool net_is_devmem_iov(const struct net_iov *niov) } #endif -void get_netmem(netmem_ref netmem); -void put_netmem(netmem_ref netmem); +void __get_netmem(netmem_ref netmem); +void __put_netmem(netmem_ref netmem); + +static __always_inline void get_netmem(netmem_ref netmem) +{ + if (netmem_is_net_iov(netmem)) + __get_netmem(netmem); + else + get_page(netmem_to_page(netmem)); +} + +static __always_inline void put_netmem(netmem_ref netmem) +{ + if (netmem_is_net_iov(netmem)) + __put_netmem(netmem); + else + put_page(netmem_to_page(netmem)); +} #define netmem_dma_unmap_addr_set(NETMEM, PTR, ADDR_NAME, VAL) \ do { \ diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 1a84c5a3c446..4d3920e5b141 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -7423,31 +7423,20 @@ bool csum_and_copy_from_iter_full(void *addr, size_t bytes, } EXPORT_SYMBOL(csum_and_copy_from_iter_full); -void get_netmem(netmem_ref netmem) +void __get_netmem(netmem_ref netmem) { - struct net_iov *niov; + struct net_iov *niov = netmem_to_net_iov(netmem); - if (netmem_is_net_iov(netmem)) { - niov = netmem_to_net_iov(netmem); - if (net_is_devmem_iov(niov)) - net_devmem_get_net_iov(netmem_to_net_iov(netmem)); - return; - } - get_page(netmem_to_page(netmem)); + if (net_is_devmem_iov(niov)) + net_devmem_get_net_iov(netmem_to_net_iov(netmem)); } -EXPORT_SYMBOL(get_netmem); +EXPORT_SYMBOL(__get_netmem); -void put_netmem(netmem_ref netmem) +void __put_netmem(netmem_ref netmem) { - struct net_iov *niov; + struct net_iov *niov = netmem_to_net_iov(netmem); - if (netmem_is_net_iov(netmem)) { - niov = netmem_to_net_iov(netmem); - if (net_is_devmem_iov(niov)) - net_devmem_put_net_iov(netmem_to_net_iov(netmem)); - return; - } - - put_page(netmem_to_page(netmem)); + if (net_is_devmem_iov(niov)) + net_devmem_put_net_iov(netmem_to_net_iov(netmem)); } -EXPORT_SYMBOL(put_netmem); +EXPORT_SYMBOL(__put_netmem);