Merge branch 'add-support-for-so_priority-cmsg'

Anna Emese Nyiri says:

====================
Add support for SO_PRIORITY cmsg

Introduce a new helper function, `sk_set_prio_allowed`,
to centralize the logic for validating priority settings.
Add support for the `SO_PRIORITY` control message,
enabling user-space applications to set socket priority
via control messages (cmsg).
====================

Link: https://patch.msgid.link/20241213084457.45120-1-annaemesenyiri@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jakub Kicinski
2024-12-16 18:11:21 -08:00
23 changed files with 228 additions and 16 deletions

View File

@@ -148,6 +148,8 @@
#define SCM_TS_OPT_ID 81
#define SO_RCVPRIORITY 82
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64

View File

@@ -159,6 +159,8 @@
#define SCM_TS_OPT_ID 81
#define SO_RCVPRIORITY 82
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64

View File

@@ -140,6 +140,8 @@
#define SCM_TS_OPT_ID 0x404C
#define SO_RCVPRIORITY 0x404D
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64

View File

@@ -141,6 +141,8 @@
#define SCM_TS_OPT_ID 0x005a
#define SO_RCVPRIORITY 0x005b
#if !defined(__KERNEL__)

View File

@@ -172,7 +172,7 @@ struct inet_cork {
u8 tx_flags;
__u8 ttl;
__s16 tos;
char priority;
u32 priority;
__u16 gso_size;
u32 ts_opt_id;
u64 transmit_time;

View File

@@ -81,7 +81,6 @@ struct ipcm_cookie {
__u8 protocol;
__u8 ttl;
__s16 tos;
char priority;
__u16 gso_size;
};
@@ -96,6 +95,7 @@ static inline void ipcm_init_sk(struct ipcm_cookie *ipcm,
ipcm_init(ipcm);
ipcm->sockc.mark = READ_ONCE(inet->sk.sk_mark);
ipcm->sockc.priority = READ_ONCE(inet->sk.sk_priority);
ipcm->sockc.tsflags = READ_ONCE(inet->sk.sk_tsflags);
ipcm->oif = READ_ONCE(inet->sk.sk_bound_dev_if);
ipcm->addr = inet->inet_saddr;

View File

@@ -953,6 +953,7 @@ enum sock_flags {
SOCK_XDP, /* XDP is attached */
SOCK_TSTAMP_NEW, /* Indicates 64 bit timestamps always */
SOCK_RCVMARK, /* Receive SO_MARK ancillary data with packet */
SOCK_RCVPRIORITY, /* Receive SO_PRIORITY ancillary data with packet */
};
#define SK_FLAGS_TIMESTAMP ((1UL << SOCK_TIMESTAMP) | (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE))
@@ -1814,13 +1815,15 @@ struct sockcm_cookie {
u32 mark;
u32 tsflags;
u32 ts_opt_id;
u32 priority;
};
static inline void sockcm_init(struct sockcm_cookie *sockc,
const struct sock *sk)
{
*sockc = (struct sockcm_cookie) {
.tsflags = READ_ONCE(sk->sk_tsflags)
.tsflags = READ_ONCE(sk->sk_tsflags),
.priority = READ_ONCE(sk->sk_priority),
};
}
@@ -2658,7 +2661,8 @@ static inline void sock_recv_cmsgs(struct msghdr *msg, struct sock *sk,
{
#define FLAGS_RECV_CMSGS ((1UL << SOCK_RXQ_OVFL) | \
(1UL << SOCK_RCVTSTAMP) | \
(1UL << SOCK_RCVMARK))
(1UL << SOCK_RCVMARK) |\
(1UL << SOCK_RCVPRIORITY))
#define TSFLAGS_ANY (SOF_TIMESTAMPING_SOFTWARE | \
SOF_TIMESTAMPING_RAW_HARDWARE)

View File

@@ -143,6 +143,8 @@
#define SCM_TS_OPT_ID 81
#define SO_RCVPRIORITY 82
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))

View File

@@ -962,7 +962,7 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
}
skb->dev = dev;
skb->priority = READ_ONCE(sk->sk_priority);
skb->priority = sockc.priority;
skb->mark = READ_ONCE(sk->sk_mark);
skb->tstamp = sockc.transmit_time;

View File

@@ -454,6 +454,13 @@ static int sock_set_timeout(long *timeo_p, sockptr_t optval, int optlen,
return 0;
}
/*
 * Tell whether priority @val may be applied on behalf of socket @sk.
 *
 * Any value inside [TC_PRIO_BESTEFFORT, TC_PRIO_INTERACTIVE] is accepted
 * unconditionally.  A value outside that range is accepted only when
 * sockopt_ns_capable() grants CAP_NET_RAW or CAP_NET_ADMIN in the user
 * namespace of the socket's network namespace.
 *
 * Returns true when the priority is allowed, false otherwise.
 */
static bool sk_set_prio_allowed(const struct sock *sk, int val)
{
	/* In-range values never reach the capability checks. */
	if (val >= TC_PRIO_BESTEFFORT && val <= TC_PRIO_INTERACTIVE)
		return true;

	if (sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW))
		return true;

	return sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN);
}
static bool sock_needs_netstamp(const struct sock *sk)
{
switch (sk->sk_family) {
@@ -1193,9 +1200,7 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
/* handle options which do not require locking the socket. */
switch (optname) {
case SO_PRIORITY:
if ((val >= 0 && val <= 6) ||
sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) ||
sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
if (sk_set_prio_allowed(sk, val)) {
sock_set_priority(sk, val);
return 0;
}
@@ -1514,6 +1519,10 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
sock_valbool_flag(sk, SOCK_RCVMARK, valbool);
break;
case SO_RCVPRIORITY:
sock_valbool_flag(sk, SOCK_RCVPRIORITY, valbool);
break;
case SO_RXQ_OVFL:
sock_valbool_flag(sk, SOCK_RXQ_OVFL, valbool);
break;
@@ -1942,6 +1951,10 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
v.val = sock_flag(sk, SOCK_RCVMARK);
break;
case SO_RCVPRIORITY:
v.val = sock_flag(sk, SOCK_RCVPRIORITY);
break;
case SO_RXQ_OVFL:
v.val = sock_flag(sk, SOCK_RXQ_OVFL);
break;
@@ -2942,6 +2955,13 @@ int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg,
case SCM_RIGHTS:
case SCM_CREDENTIALS:
break;
case SO_PRIORITY:
if (cmsg->cmsg_len != CMSG_LEN(sizeof(u32)))
return -EINVAL;
if (!sk_set_prio_allowed(sk, *(u32 *)CMSG_DATA(cmsg)))
return -EPERM;
sockc->priority = *(u32 *)CMSG_DATA(cmsg);
break;
default:
return -EINVAL;
}

View File

@@ -1333,7 +1333,7 @@ static int ip_setup_cork(struct sock *sk, struct inet_cork *cork,
cork->ttl = ipc->ttl;
cork->tos = ipc->tos;
cork->mark = ipc->sockc.mark;
cork->priority = ipc->priority;
cork->priority = ipc->sockc.priority;
cork->transmit_time = ipc->sockc.transmit_time;
cork->tx_flags = 0;
sock_tx_timestamp(sk, &ipc->sockc, &cork->tx_flags);
@@ -1470,7 +1470,7 @@ struct sk_buff *__ip_make_skb(struct sock *sk,
ip_options_build(skb, opt, cork->addr, rt);
}
skb->priority = (cork->tos != -1) ? cork->priority: READ_ONCE(sk->sk_priority);
skb->priority = cork->priority;
skb->mark = cork->mark;
if (sk_is_tcp(sk))
skb_set_delivery_time(skb, cork->transmit_time, SKB_CLOCK_MONOTONIC);

View File

@@ -315,7 +315,7 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc,
if (val < 0 || val > 255)
return -EINVAL;
ipc->tos = val;
ipc->priority = rt_tos2priority(ipc->tos);
ipc->sockc.priority = rt_tos2priority(ipc->tos);
break;
case IP_PROTOCOL:
if (cmsg->cmsg_len != CMSG_LEN(sizeof(int)))

View File

@@ -358,7 +358,7 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4,
skb_reserve(skb, hlen);
skb->protocol = htons(ETH_P_IP);
skb->priority = READ_ONCE(sk->sk_priority);
skb->priority = sockc->priority;
skb->mark = sockc->mark;
skb_set_delivery_type_by_clockid(skb, sockc->transmit_time, sk->sk_clockid);
skb_dst_set(skb, &rt->dst);

View File

@@ -1401,6 +1401,7 @@ static int ip6_setup_cork(struct sock *sk, struct inet_cork_full *cork,
cork->base.gso_size = ipc6->gso_size;
cork->base.tx_flags = 0;
cork->base.mark = ipc6->sockc.mark;
cork->base.priority = ipc6->sockc.priority;
sock_tx_timestamp(sk, &ipc6->sockc, &cork->base.tx_flags);
if (ipc6->sockc.tsflags & SOCKCM_FLAG_TS_OPT_ID) {
cork->base.flags |= IPCORK_TS_OPT_ID;
@@ -1942,7 +1943,7 @@ struct sk_buff *__ip6_make_skb(struct sock *sk,
hdr->saddr = fl6->saddr;
hdr->daddr = *final_dst;
skb->priority = READ_ONCE(sk->sk_priority);
skb->priority = cork->base.priority;
skb->mark = cork->base.mark;
if (sk_is_tcp(sk))
skb_set_delivery_time(skb, cork->base.transmit_time, SKB_CLOCK_MONOTONIC);

View File

@@ -119,6 +119,7 @@ static int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
return -EINVAL;
ipcm6_init_sk(&ipc6, sk);
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = READ_ONCE(sk->sk_mark);

View File

@@ -619,7 +619,7 @@ static int rawv6_send_hdrinc(struct sock *sk, struct msghdr *msg, int length,
skb_reserve(skb, hlen);
skb->protocol = htons(ETH_P_IPV6);
skb->priority = READ_ONCE(sk->sk_priority);
skb->priority = sockc->priority;
skb->mark = sockc->mark;
skb_set_delivery_type_by_clockid(skb, sockc->transmit_time, sk->sk_clockid);
@@ -780,6 +780,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
ipcm6_init(&ipc6);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = fl6.flowi6_mark;
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);
if (sin6) {
if (addr_len < SIN6_LEN_RFC2133)

View File

@@ -1448,6 +1448,7 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
ipc6.gso_size = READ_ONCE(up->gso_size);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = READ_ONCE(sk->sk_mark);
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);
/* destination address check */
if (sin6) {

View File

@@ -3126,7 +3126,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
skb->protocol = proto;
skb->dev = dev;
skb->priority = READ_ONCE(sk->sk_priority);
skb->priority = sockc.priority;
skb->mark = sockc.mark;
skb_set_delivery_type_by_clockid(skb, sockc.transmit_time, sk->sk_clockid);

View File

@@ -1008,12 +1008,23 @@ static void sock_recv_mark(struct msghdr *msg, struct sock *sk,
}
}
/*
 * Attach the priority of @skb to @msg as a SOL_SOCKET/SO_PRIORITY control
 * message.  Nothing is emitted unless SOCK_RCVPRIORITY is set on @sk and
 * @skb is non-NULL.
 */
static void sock_recv_priority(struct msghdr *msg, struct sock *sk,
			       struct sk_buff *skb)
{
	__u32 prio;

	if (!sock_flag(sk, SOCK_RCVPRIORITY) || !skb)
		return;

	/* Copy into a local so put_cmsg() is handed a plain __u32. */
	prio = skb->priority;
	put_cmsg(msg, SOL_SOCKET, SO_PRIORITY, sizeof(prio), &prio);
}
/*
 * Run the receive-side control message helpers for @skb on socket @sk,
 * adding their output to @msg.  The helpers are invoked in a fixed order:
 * timestamp, drops, mark, priority.
 *
 * sock_recv_priority() tests its own socket flag and the skb pointer before
 * emitting anything; the remaining helpers are defined outside this hunk.
 */
void __sock_recv_cmsgs(struct msghdr *msg, struct sock *sk,
struct sk_buff *skb)
{
sock_recv_timestamp(msg, sk, skb);
sock_recv_drops(msg, sk, skb);
sock_recv_mark(msg, sk, skb);
sock_recv_priority(msg, sk, skb);
}
EXPORT_SYMBOL_GPL(__sock_recv_cmsgs);

View File

@@ -126,6 +126,8 @@
#define SCM_TS_OPT_ID 78
#define SO_RCVPRIORITY 79
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))

View File

@@ -32,6 +32,7 @@ TEST_PROGS += ioam6.sh
TEST_PROGS += gro.sh
TEST_PROGS += gre_gso.sh
TEST_PROGS += cmsg_so_mark.sh
TEST_PROGS += cmsg_so_priority.sh
TEST_PROGS += cmsg_time.sh cmsg_ipv6.sh
TEST_PROGS += netns-name.sh
TEST_PROGS += nl_netdev.py

View File

@@ -59,6 +59,7 @@ struct options {
unsigned int proto;
} sock;
struct option_cmsg_u32 mark;
struct option_cmsg_u32 priority;
struct {
bool ena;
unsigned int delay;
@@ -97,6 +98,8 @@ static void __attribute__((noreturn)) cs_usage(const char *bin)
"\n"
"\t\t-m val Set SO_MARK with given value\n"
"\t\t-M val Set SO_MARK via setsockopt\n"
"\t\t-P val Set SO_PRIORITY via setsockopt\n"
"\t\t-Q val Set SO_PRIORITY via cmsg\n"
"\t\t-d val Set SO_TXTIME with given delay (usec)\n"
"\t\t-t Enable time stamp reporting\n"
"\t\t-f val Set don't fragment via cmsg\n"
@@ -115,7 +118,7 @@ static void cs_parse_args(int argc, char *argv[])
{
int o;
while ((o = getopt(argc, argv, "46sS:p:P:m:M:n:d:tf:F:c:C:l:L:H:")) != -1) {
while ((o = getopt(argc, argv, "46sS:p:P:m:M:n:d:tf:F:c:C:l:L:H:Q:")) != -1) {
switch (o) {
case 's':
opt.silent_send = true;
@@ -148,6 +151,10 @@ static void cs_parse_args(int argc, char *argv[])
opt.mark.ena = true;
opt.mark.val = atoi(optarg);
break;
case 'Q':
opt.priority.ena = true;
opt.priority.val = atoi(optarg);
break;
case 'M':
opt.sockopt.mark = atoi(optarg);
break;
@@ -252,6 +259,8 @@ cs_write_cmsg(int fd, struct msghdr *msg, char *cbuf, size_t cbuf_sz)
ca_write_cmsg_u32(cbuf, cbuf_sz, &cmsg_len,
SOL_SOCKET, SO_MARK, &opt.mark);
ca_write_cmsg_u32(cbuf, cbuf_sz, &cmsg_len,
SOL_SOCKET, SO_PRIORITY, &opt.priority);
ca_write_cmsg_u32(cbuf, cbuf_sz, &cmsg_len,
SOL_IPV6, IPV6_DONTFRAG, &opt.v6.dontfrag);
ca_write_cmsg_u32(cbuf, cbuf_sz, &cmsg_len,

View File

@@ -0,0 +1,151 @@
#!/bin/bash
# SPDX-License-Identifier: GPL-2.0
source lib.sh
# Exit status used when the test has to be skipped.
readonly KSFT_SKIP=4
# IPv4: local address (with prefix length) and the two destinations.
# The *_RAW destination is the one used for raw-socket ("r") runs.
IP4=192.0.2.1/24
TGT4=192.0.2.2
TGT4_RAW=192.0.2.3
# IPv6 counterparts of the addresses above.
IP6=2001:db8::1/64
TGT6=2001:db8::2
TGT6_RAW=2001:db8::3
# Port argument handed to cmsg_sender.
PORT=1234
# Running totals maintained by check_result and reported at the end.
TOTAL_TESTS=0
FAILED_TESTS=0
# The packet counters are extracted from tc's JSON output with jq, so the
# test cannot produce a verdict without it: report a skip instead of failing.
if ! command -v jq &> /dev/null; then
	echo "SKIP cmsg_so_priority.sh test: jq is not installed." >&2
	exit "$KSFT_SKIP"
fi
# Record the outcome of one check.
#   $1 - 0 when the check passed, any other integer when it failed
# Updates TOTAL_TESTS and FAILED_TESTS and always returns 0.
#
# The counters are bumped with arithmetic expansion rather than ((var++)):
# the post-increment form evaluates to the old value, so it exits non-zero
# while the counter is still 0.  That status would become this function's
# return value on the first failure and would abort the script if errexit
# or an ERR trap were ever enabled.
check_result() {
	TOTAL_TESTS=$((TOTAL_TESTS + 1))
	if [ "$1" -ne 0 ]; then
		FAILED_TESTS=$((FAILED_TESTS + 1))
	fi
	return 0
}
# EXIT handler: passes the test namespace to cleanup_ns (not defined in this
# script; presumably provided by the sourced lib.sh).
cleanup()
{
cleanup_ns $NS
}
# The handler is installed before the namespace is created, so it runs on
# every exit path from here on.
trap cleanup EXIT
# setup_ns is likewise not defined here; the rest of the script uses $NS as
# the namespace name, so it is expected to set that variable.
setup_ns NS
# Add an egress flower filter on dummy1 (namespace $NS) that matches
# VLAN-tagged packets and lets them pass.
#   $1 - filter handle
#   $2 - VLAN priority to match
#   $3 - VLAN ethertype to match (ipv4 or ipv6)
#   $4 - socket type letter: "u" matches UDP, "i" matches ICMP/ICMPv6
#        according to $3; any other value adds no ip_proto match
#   $5 - destination IP address to match
create_filter() {
	local filter_handle=$1
	local pcp=$2
	local ethtype=$3
	local sock_type=$4
	local dest=$5
	local l4_proto

	case "$sock_type" in
	u)
		l4_proto="udp"
		;;
	i)
		case "$ethtype" in
		ipv4) l4_proto="icmp" ;;
		ipv6) l4_proto="icmpv6" ;;
		esac
		;;
	esac

	# l4_proto stays unset for other socket types; the expansion below
	# then contributes no words to the command line.
	tc -n $NS filter add dev dummy1 \
		egress pref 1 handle "$filter_handle" proto 802.1q \
		flower vlan_prio "$pcp" vlan_ethtype "$ethtype" \
		dst_ip "$dest" ${l4_proto:+ip_proto $l4_proto} \
		action pass
}
# Build the test topology inside $NS: loopback up, a dummy device, and a
# VLAN 10 device on top of it.  The egress-qos-map maps each packet priority
# 0..7 to the VLAN priority of the same value.
ip -n $NS link set dev lo up
ip -n $NS link add name dummy1 up type dummy
ip -n $NS link add link dummy1 name dummy1.10 up type vlan id 10 \
egress-qos-map 0:0 1:1 2:2 3:3 4:4 5:5 6:6 7:7
# Local addresses live on the VLAN device.
ip -n $NS address add $IP4 dev dummy1.10
ip -n $NS address add $IP6 dev dummy1.10 nodad
# Widen the ping socket group range to all group ids; presumably needed for
# the "i" (ICMP) socket type of cmsg_sender.
ip netns exec $NS sysctl -wq net.ipv4.ping_group_range='0 2147483647'
# Permanent neighbour entries for every destination, so no address
# resolution is needed (dummy1 has no peer that could answer).
ip -n $NS neigh add $TGT4 lladdr 00:11:22:33:44:55 nud permanent \
dev dummy1.10
ip -n $NS neigh add $TGT6 lladdr 00:11:22:33:44:55 nud permanent \
dev dummy1.10
ip -n $NS neigh add $TGT4_RAW lladdr 00:11:22:33:44:66 nud permanent \
dev dummy1.10
ip -n $NS neigh add $TGT6_RAW lladdr 00:11:22:33:44:66 nud permanent \
dev dummy1.10
# clsact supplies the egress hook the flower filters are attached to.
tc -n $NS qdisc add dev dummy1 clsact
# Print the packet counter of the dummy1 egress filter with handle $1.
get_filter_pkts() {
	tc -n $NS -j -s filter show dev dummy1 egress \
		| jq ".[] | select(.options.handle == $1) | .options.actions[0].stats.packets"
}

# Compare the packet counter of filter $1 against $2 and record the result
# with check_result.  $3 is the prefix of the message printed on mismatch.
expect_filter_pkts() {
	local seen

	seen=$(get_filter_pkts "$1")
	if [[ $seen == "$2" ]]; then
		check_result 0
	else
		echo "$3: expected $2, got $seen"
		check_result 1
	fi
}

# Filter handles are "<FILTER_COUNTER><priority>"; the counter advances by 10
# for every address family / socket type combination.
FILTER_COUNTER=10

for i in 4 6; do
	for proto in u i r; do
		echo "Test IPV$i, prot: $proto"

		# Raw sockets use their own destination address; the choice
		# does not depend on the priority, so make it once per run.
		case "$i$proto" in
		4r) TGT=$TGT4_RAW ;;
		6r) TGT=$TGT6_RAW ;;
		4*) TGT=$TGT4 ;;
		*)  TGT=$TGT6 ;;
		esac

		for priority in {0..7}; do
			handle="${FILTER_COUNTER}${priority}"

			# A fresh filter must not have seen any packet yet.
			create_filter $handle $priority ipv$i $proto $TGT
			expect_filter_pkts $handle 0 "prio $priority"

			# -Q: priority set via cmsg, one packet expected.
			ip netns exec $NS ./cmsg_sender -$i -Q $priority \
				-p $proto $TGT $PORT
			expect_filter_pkts $handle 1 "prio $priority -Q"

			# -P: priority set via setsockopt, a second packet expected.
			ip netns exec $NS ./cmsg_sender -$i -P $priority \
				-p $proto $TGT $PORT
			expect_filter_pkts $handle 2 "prio $priority -P"
		done

		FILTER_COUNTER=$((FILTER_COUNTER + 10))
	done
done
# Summarise the run: one line of output, exit status 1 if any check failed
# and 0 otherwise.
if [ $FAILED_TESTS -ne 0 ]; then
	summary="FAIL - $FAILED_TESTS/$TOTAL_TESTS tests failed"
	rc=1
else
	summary="OK - All $TOTAL_TESTS tests passed"
	rc=0
fi

echo "$summary"
exit $rc