xfrm: avoid RCU warnings around the per-netns netlink socket

net->xfrm.nlsk is used in 2 types of contexts:
 - fully under RCU, with rcu_read_lock + rcu_dereference and a NULL check
 - in the netlink handlers, with requests coming from a userspace socket

In the 2nd case, net->xfrm.nlsk is guaranteed to stay non-NULL and the
object is alive, since we can't enter the netns destruction path while
the user socket holds a reference on the netns.

After adding the __rcu annotation to netns_xfrm.nlsk (which silences
sparse warnings in the RCU users and __net_init code), we need to tell
sparse that the 2nd case is safe. Add a helper for that.

Signed-off-by: Sabrina Dubroca <sd@queasysnail.net>
Reviewed-by: Simon Horman <horms@kernel.org>
Signed-off-by: Steffen Klassert <steffen.klassert@secunet.com>
This commit is contained in:
Sabrina Dubroca
2026-03-09 11:32:43 +01:00
committed by Steffen Klassert
parent 103b4f5b40
commit d87f8bc47f
2 changed files with 18 additions and 9 deletions

View File

@@ -59,7 +59,7 @@ struct netns_xfrm {
struct list_head inexact_bins;
struct sock *nlsk;
struct sock __rcu *nlsk;
struct sock *nlsk_stash;
u32 sysctl_aevent_etime;

View File

@@ -35,6 +35,15 @@
#endif
#include <linux/unaligned.h>
static struct sock *xfrm_net_nlsk(const struct net *net, const struct sk_buff *skb)
{
/* get the source of this request, see netlink_unicast_kernel */
const struct sock *sk = NETLINK_CB(skb).sk;
/* sk is refcounted, the netns stays alive and nlsk with it */
return rcu_dereference_protected(net->xfrm.nlsk, sk->sk_net_refcnt);
}
static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type,
struct netlink_ext_ack *extack)
{
@@ -1727,7 +1736,7 @@ static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
err = build_spdinfo(r_skb, net, sportid, seq, *flags);
BUG_ON(err < 0);
return nlmsg_unicast(net->xfrm.nlsk, r_skb, sportid);
return nlmsg_unicast(xfrm_net_nlsk(net, skb), r_skb, sportid);
}
static inline unsigned int xfrm_sadinfo_msgsize(void)
@@ -1787,7 +1796,7 @@ static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
err = build_sadinfo(r_skb, net, sportid, seq, *flags);
BUG_ON(err < 0);
return nlmsg_unicast(net->xfrm.nlsk, r_skb, sportid);
return nlmsg_unicast(xfrm_net_nlsk(net, skb), r_skb, sportid);
}
static int xfrm_get_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
@@ -1807,7 +1816,7 @@ static int xfrm_get_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
if (IS_ERR(resp_skb)) {
err = PTR_ERR(resp_skb);
} else {
err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, NETLINK_CB(skb).portid);
err = nlmsg_unicast(xfrm_net_nlsk(net, skb), resp_skb, NETLINK_CB(skb).portid);
}
xfrm_state_put(x);
out_noput:
@@ -1898,7 +1907,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
}
}
err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, NETLINK_CB(skb).portid);
err = nlmsg_unicast(xfrm_net_nlsk(net, skb), resp_skb, NETLINK_CB(skb).portid);
out:
xfrm_state_put(x);
@@ -2543,7 +2552,7 @@ static int xfrm_get_default(struct sk_buff *skb, struct nlmsghdr *nlh,
r_up->out = net->xfrm.policy_default[XFRM_POLICY_OUT];
nlmsg_end(r_skb, r_nlh);
return nlmsg_unicast(net->xfrm.nlsk, r_skb, portid);
return nlmsg_unicast(xfrm_net_nlsk(net, skb), r_skb, portid);
}
static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
@@ -2609,7 +2618,7 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
if (IS_ERR(resp_skb)) {
err = PTR_ERR(resp_skb);
} else {
err = nlmsg_unicast(net->xfrm.nlsk, resp_skb,
err = nlmsg_unicast(xfrm_net_nlsk(net, skb), resp_skb,
NETLINK_CB(skb).portid);
}
} else {
@@ -2782,7 +2791,7 @@ static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh,
err = build_aevent(r_skb, x, &c);
BUG_ON(err < 0);
err = nlmsg_unicast(net->xfrm.nlsk, r_skb, NETLINK_CB(skb).portid);
err = nlmsg_unicast(xfrm_net_nlsk(net, skb), r_skb, NETLINK_CB(skb).portid);
spin_unlock_bh(&x->lock);
xfrm_state_put(x);
return err;
@@ -3486,7 +3495,7 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
goto err;
}
err = netlink_dump_start(net->xfrm.nlsk, skb, nlh, &c);
err = netlink_dump_start(xfrm_net_nlsk(net, skb), skb, nlh, &c);
goto err;
}