mirror of
https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
synced 2026-05-09 13:43:21 -04:00
Merge branch 'bpf-fix-backward-progress-bug-in-bpf_iter_udp'
Martin KaFai Lau says: ==================== bpf: Fix backward progress bug in bpf_iter_udp From: Martin KaFai Lau <martin.lau@kernel.org> This patch set fixes an issue in bpf_iter_udp that makes backward progress and prevents the user space process from finishing. There is a test at the end to reproduce the bug. Please see individual patches for details. v3: - Fixed the iter_fd check and local_port check in the patch 3 selftest. (Yonghong) - Moved jhash2 to test_jhash.h in the patch 3. (Yonghong) - Added explanation in the bucket selection in the patch 3. (Yonghong) v2: - Added patch 1 to fix another bug that goes back to the previous bucket - Simplify the fix in patch 2 to always reset iter->offset to 0 - Add a test case to close all udp_sk in a bucket while in the middle of the iteration. ==================== Link: https://lore.kernel.org/r/20240112190530.3751661-1-martin.lau@linux.dev Signed-off-by: Alexei Starovoitov <ast@kernel.org>
This commit is contained in:
@@ -3137,16 +3137,18 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
|
||||
struct bpf_udp_iter_state *iter = seq->private;
|
||||
struct udp_iter_state *state = &iter->state;
|
||||
struct net *net = seq_file_net(seq);
|
||||
int resume_bucket, resume_offset;
|
||||
struct udp_table *udptable;
|
||||
unsigned int batch_sks = 0;
|
||||
bool resized = false;
|
||||
struct sock *sk;
|
||||
|
||||
resume_bucket = state->bucket;
|
||||
resume_offset = iter->offset;
|
||||
|
||||
/* The current batch is done, so advance the bucket. */
|
||||
if (iter->st_bucket_done) {
|
||||
if (iter->st_bucket_done)
|
||||
state->bucket++;
|
||||
iter->offset = 0;
|
||||
}
|
||||
|
||||
udptable = udp_get_table_seq(seq, net);
|
||||
|
||||
@@ -3166,19 +3168,19 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
|
||||
for (; state->bucket <= udptable->mask; state->bucket++) {
|
||||
struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];
|
||||
|
||||
if (hlist_empty(&hslot2->head)) {
|
||||
iter->offset = 0;
|
||||
if (hlist_empty(&hslot2->head))
|
||||
continue;
|
||||
}
|
||||
|
||||
iter->offset = 0;
|
||||
spin_lock_bh(&hslot2->lock);
|
||||
udp_portaddr_for_each_entry(sk, &hslot2->head) {
|
||||
if (seq_sk_match(seq, sk)) {
|
||||
/* Resume from the last iterated socket at the
|
||||
* offset in the bucket before iterator was stopped.
|
||||
*/
|
||||
if (iter->offset) {
|
||||
--iter->offset;
|
||||
if (state->bucket == resume_bucket &&
|
||||
iter->offset < resume_offset) {
|
||||
++iter->offset;
|
||||
continue;
|
||||
}
|
||||
if (iter->end_sk < iter->max_sk) {
|
||||
@@ -3192,9 +3194,6 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
|
||||
|
||||
if (iter->end_sk)
|
||||
break;
|
||||
|
||||
/* Reset the current bucket's offset before moving to the next bucket. */
|
||||
iter->offset = 0;
|
||||
}
|
||||
|
||||
/* All done: no batch made. */
|
||||
@@ -3213,7 +3212,6 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
|
||||
/* After allocating a larger batch, retry one more time to grab
|
||||
* the whole bucket.
|
||||
*/
|
||||
state->bucket--;
|
||||
goto again;
|
||||
}
|
||||
done:
|
||||
|
||||
135
tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
Normal file
135
tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
Normal file
@@ -0,0 +1,135 @@
|
||||
// SPDX-License-Identifier: GPL-2.0
|
||||
// Copyright (c) 2024 Meta
|
||||
|
||||
#include <test_progs.h>
|
||||
#include "network_helpers.h"
|
||||
#include "sock_iter_batch.skel.h"
|
||||
|
||||
#define TEST_NS "sock_iter_batch_netns"
|
||||
|
||||
static const int nr_soreuse = 4;
|
||||
|
||||
static void do_test(int sock_type, bool onebyone)
|
||||
{
|
||||
int err, i, nread, to_read, total_read, iter_fd = -1;
|
||||
int first_idx, second_idx, indices[nr_soreuse];
|
||||
struct bpf_link *link = NULL;
|
||||
struct sock_iter_batch *skel;
|
||||
int *fds[2] = {};
|
||||
|
||||
skel = sock_iter_batch__open();
|
||||
if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
|
||||
return;
|
||||
|
||||
/* Prepare 2 buckets of sockets in the kernel hashtable */
|
||||
for (i = 0; i < ARRAY_SIZE(fds); i++) {
|
||||
int local_port;
|
||||
|
||||
fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
|
||||
nr_soreuse);
|
||||
if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
|
||||
goto done;
|
||||
local_port = get_socket_local_port(*fds[i]);
|
||||
if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
|
||||
goto done;
|
||||
skel->rodata->ports[i] = ntohs(local_port);
|
||||
}
|
||||
|
||||
err = sock_iter_batch__load(skel);
|
||||
if (!ASSERT_OK(err, "sock_iter_batch__load"))
|
||||
goto done;
|
||||
|
||||
link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
|
||||
skel->progs.iter_tcp_soreuse :
|
||||
skel->progs.iter_udp_soreuse,
|
||||
NULL);
|
||||
if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
|
||||
goto done;
|
||||
|
||||
iter_fd = bpf_iter_create(bpf_link__fd(link));
|
||||
if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
|
||||
goto done;
|
||||
|
||||
/* Test reading a bucket (either from fds[0] or fds[1]).
|
||||
* Only read "nr_soreuse - 1" number of sockets
|
||||
* from a bucket and leave one socket out from
|
||||
* that bucket on purpose.
|
||||
*/
|
||||
to_read = (nr_soreuse - 1) * sizeof(*indices);
|
||||
total_read = 0;
|
||||
first_idx = -1;
|
||||
do {
|
||||
nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
|
||||
if (nread <= 0 || nread % sizeof(*indices))
|
||||
break;
|
||||
total_read += nread;
|
||||
|
||||
if (first_idx == -1)
|
||||
first_idx = indices[0];
|
||||
for (i = 0; i < nread / sizeof(*indices); i++)
|
||||
ASSERT_EQ(indices[i], first_idx, "first_idx");
|
||||
} while (total_read < to_read);
|
||||
ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread");
|
||||
ASSERT_EQ(total_read, to_read, "total_read");
|
||||
|
||||
free_fds(fds[first_idx], nr_soreuse);
|
||||
fds[first_idx] = NULL;
|
||||
|
||||
/* Read the "whole" second bucket */
|
||||
to_read = nr_soreuse * sizeof(*indices);
|
||||
total_read = 0;
|
||||
second_idx = !first_idx;
|
||||
do {
|
||||
nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
|
||||
if (nread <= 0 || nread % sizeof(*indices))
|
||||
break;
|
||||
total_read += nread;
|
||||
|
||||
for (i = 0; i < nread / sizeof(*indices); i++)
|
||||
ASSERT_EQ(indices[i], second_idx, "second_idx");
|
||||
} while (total_read <= to_read);
|
||||
ASSERT_EQ(nread, 0, "nread");
|
||||
/* Both so_reuseport ports should be in different buckets, so
|
||||
* total_read must equal to the expected to_read.
|
||||
*
|
||||
* For a very unlikely case, both ports collide at the same bucket,
|
||||
* the bucket offset (i.e. 3) will be skipped and it cannot
|
||||
* expect the to_read number of bytes.
|
||||
*/
|
||||
if (skel->bss->bucket[0] != skel->bss->bucket[1])
|
||||
ASSERT_EQ(total_read, to_read, "total_read");
|
||||
|
||||
done:
|
||||
for (i = 0; i < ARRAY_SIZE(fds); i++)
|
||||
free_fds(fds[i], nr_soreuse);
|
||||
if (iter_fd < 0)
|
||||
close(iter_fd);
|
||||
bpf_link__destroy(link);
|
||||
sock_iter_batch__destroy(skel);
|
||||
}
|
||||
|
||||
void test_sock_iter_batch(void)
|
||||
{
|
||||
struct nstoken *nstoken = NULL;
|
||||
|
||||
SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
|
||||
SYS(done, "ip netns add %s", TEST_NS);
|
||||
SYS(done, "ip -net %s link set dev lo up", TEST_NS);
|
||||
|
||||
nstoken = open_netns(TEST_NS);
|
||||
if (!ASSERT_OK_PTR(nstoken, "open_netns"))
|
||||
goto done;
|
||||
|
||||
if (test__start_subtest("tcp")) {
|
||||
do_test(SOCK_STREAM, true);
|
||||
do_test(SOCK_STREAM, false);
|
||||
}
|
||||
if (test__start_subtest("udp")) {
|
||||
do_test(SOCK_DGRAM, true);
|
||||
do_test(SOCK_DGRAM, false);
|
||||
}
|
||||
close_netns(nstoken);
|
||||
|
||||
done:
|
||||
SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
|
||||
}
|
||||
@@ -72,6 +72,8 @@
|
||||
#define inet_rcv_saddr sk.__sk_common.skc_rcv_saddr
|
||||
#define inet_dport sk.__sk_common.skc_dport
|
||||
|
||||
#define udp_portaddr_hash inet.sk.__sk_common.skc_u16hashes[1]
|
||||
|
||||
#define ir_loc_addr req.__req_common.skc_rcv_saddr
|
||||
#define ir_num req.__req_common.skc_num
|
||||
#define ir_rmt_addr req.__req_common.skc_daddr
|
||||
@@ -85,6 +87,7 @@
|
||||
#define sk_rmem_alloc sk_backlog.rmem_alloc
|
||||
#define sk_refcnt __sk_common.skc_refcnt
|
||||
#define sk_state __sk_common.skc_state
|
||||
#define sk_net __sk_common.skc_net
|
||||
#define sk_v6_daddr __sk_common.skc_v6_daddr
|
||||
#define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr
|
||||
#define sk_flags __sk_common.skc_flags
|
||||
|
||||
91
tools/testing/selftests/bpf/progs/sock_iter_batch.c
Normal file
91
tools/testing/selftests/bpf/progs/sock_iter_batch.c
Normal file
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: GPL-2.0
|
||||
// Copyright (c) 2024 Meta
|
||||
|
||||
#include "vmlinux.h"
|
||||
#include <bpf/bpf_helpers.h>
|
||||
#include <bpf/bpf_core_read.h>
|
||||
#include <bpf/bpf_endian.h>
|
||||
#include "bpf_tracing_net.h"
|
||||
#include "bpf_kfuncs.h"
|
||||
|
||||
#define ATTR __always_inline
|
||||
#include "test_jhash.h"
|
||||
|
||||
static bool ipv6_addr_loopback(const struct in6_addr *a)
|
||||
{
|
||||
return (a->s6_addr32[0] | a->s6_addr32[1] |
|
||||
a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0;
|
||||
}
|
||||
|
||||
volatile const __u16 ports[2];
|
||||
unsigned int bucket[2];
|
||||
|
||||
SEC("iter/tcp")
|
||||
int iter_tcp_soreuse(struct bpf_iter__tcp *ctx)
|
||||
{
|
||||
struct sock *sk = (struct sock *)ctx->sk_common;
|
||||
struct inet_hashinfo *hinfo;
|
||||
unsigned int hash;
|
||||
struct net *net;
|
||||
int idx;
|
||||
|
||||
if (!sk)
|
||||
return 0;
|
||||
|
||||
sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
|
||||
if (sk->sk_family != AF_INET6 ||
|
||||
sk->sk_state != TCP_LISTEN ||
|
||||
!ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
|
||||
return 0;
|
||||
|
||||
if (sk->sk_num == ports[0])
|
||||
idx = 0;
|
||||
else if (sk->sk_num == ports[1])
|
||||
idx = 1;
|
||||
else
|
||||
return 0;
|
||||
|
||||
/* bucket selection as in inet_lhash2_bucket_sk() */
|
||||
net = sk->sk_net.net;
|
||||
hash = jhash2(sk->sk_v6_rcv_saddr.s6_addr32, 4, net->hash_mix);
|
||||
hash ^= sk->sk_num;
|
||||
hinfo = net->ipv4.tcp_death_row.hashinfo;
|
||||
bucket[idx] = hash & hinfo->lhash2_mask;
|
||||
bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define udp_sk(ptr) container_of(ptr, struct udp_sock, inet.sk)
|
||||
|
||||
SEC("iter/udp")
|
||||
int iter_udp_soreuse(struct bpf_iter__udp *ctx)
|
||||
{
|
||||
struct sock *sk = (struct sock *)ctx->udp_sk;
|
||||
struct udp_table *udptable;
|
||||
int idx;
|
||||
|
||||
if (!sk)
|
||||
return 0;
|
||||
|
||||
sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
|
||||
if (sk->sk_family != AF_INET6 ||
|
||||
!ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
|
||||
return 0;
|
||||
|
||||
if (sk->sk_num == ports[0])
|
||||
idx = 0;
|
||||
else if (sk->sk_num == ports[1])
|
||||
idx = 1;
|
||||
else
|
||||
return 0;
|
||||
|
||||
/* bucket selection as in udp_hashslot2() */
|
||||
udptable = sk->sk_net.net->ipv4.udp_table;
|
||||
bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask;
|
||||
bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
char _license[] SEC("license") = "GPL";
|
||||
@@ -69,3 +69,34 @@ u32 jhash(const void *key, u32 length, u32 initval)
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
static __always_inline u32 jhash2(const u32 *k, u32 length, u32 initval)
|
||||
{
|
||||
u32 a, b, c;
|
||||
|
||||
/* Set up the internal state */
|
||||
a = b = c = JHASH_INITVAL + (length<<2) + initval;
|
||||
|
||||
/* Handle most of the key */
|
||||
while (length > 3) {
|
||||
a += k[0];
|
||||
b += k[1];
|
||||
c += k[2];
|
||||
__jhash_mix(a, b, c);
|
||||
length -= 3;
|
||||
k += 3;
|
||||
}
|
||||
|
||||
/* Handle the last 3 u32's */
|
||||
switch (length) {
|
||||
case 3: c += k[2];
|
||||
case 2: b += k[1];
|
||||
case 1: a += k[0];
|
||||
__jhash_final(a, b, c);
|
||||
break;
|
||||
case 0: /* Nothing left to add */
|
||||
break;
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user