diff --git a/include/linux/list_nulls.h b/include/linux/list_nulls.h index 5d10ae364b5e..e8c300e06438 100644 --- a/include/linux/list_nulls.h +++ b/include/linux/list_nulls.h @@ -21,8 +21,9 @@ struct hlist_nulls_head { struct hlist_nulls_node { struct hlist_nulls_node *next, **pprev; }; +#define NULLS_MARKER(value) (1UL | (((long)value) << 1)) #define INIT_HLIST_NULLS_HEAD(ptr, nulls) \ - ((ptr)->first = (struct hlist_nulls_node *) (1UL | (((long)nulls) << 1))) + ((ptr)->first = (struct hlist_nulls_node *) NULLS_MARKER(nulls)) #define hlist_nulls_entry(ptr, type, member) container_of(ptr,type,member) /** diff --git a/include/linux/rhashtable.h b/include/linux/rhashtable.h index b93fd89b2e5e..de7cac753b09 100644 --- a/include/linux/rhashtable.h +++ b/include/linux/rhashtable.h @@ -18,16 +18,43 @@ #ifndef _LINUX_RHASHTABLE_H #define _LINUX_RHASHTABLE_H -#include +#include +#include + +/* + * The end of the chain is marked with a special nulls marks which has + * the following format: + * + * +-------+-----------------------------------------------------+-+ + * | Base | Hash |1| + * +-------+-----------------------------------------------------+-+ + * + * Base (4 bits) : Reserved to distinguish between multiple tables. + * Specified via &struct rhashtable_params.nulls_base. + * Hash (27 bits): Full hash (unmasked) of first element added to bucket + * 1 (1 bit) : Nulls marker (always set) + * + * The remaining bits of the next pointer remain unused for now. + */ +#define RHT_BASE_BITS 4 +#define RHT_HASH_BITS 27 +#define RHT_BASE_SHIFT RHT_HASH_BITS struct rhash_head { struct rhash_head __rcu *next; }; -#define INIT_HASH_HEAD(ptr) ((ptr)->next = NULL) - +/** + * struct bucket_table - Table of hash buckets + * @size: Number of hash buckets + * @locks_mask: Mask to apply before accessing locks[] + * @locks: Array of spinlocks protecting individual buckets + * @buckets: size * hash buckets + */ struct bucket_table { size_t size; + unsigned int locks_mask; + spinlock_t *locks; struct rhash_head __rcu *buckets[]; }; @@ -45,11 +72,12 @@ struct rhashtable; * @hash_rnd: Seed to use while hashing * @max_shift: Maximum number of shifts while expanding * @min_shift: Minimum number of shifts while shrinking + * @nulls_base: Base value to generate nulls marker + * @locks_mul: Number of bucket locks to allocate per cpu (default: 128) * @hashfn: Function to hash key * @obj_hashfn: Function to hash object * @grow_decision: If defined, may return true if table should expand * @shrink_decision: If defined, may return true if table should shrink - * @mutex_is_held: Must return true if protecting mutex is held */ struct rhashtable_params { size_t nelem_hint; @@ -59,36 +87,67 @@ struct rhashtable_params { u32 hash_rnd; size_t max_shift; size_t min_shift; + u32 nulls_base; + size_t locks_mul; rht_hashfn_t hashfn; rht_obj_hashfn_t obj_hashfn; bool (*grow_decision)(const struct rhashtable *ht, size_t new_size); bool (*shrink_decision)(const struct rhashtable *ht, size_t new_size); -#ifdef CONFIG_PROVE_LOCKING - int (*mutex_is_held)(void *parent); - void *parent; -#endif }; /** * struct rhashtable - Hash table handle * @tbl: Bucket table + * @future_tbl: Table under construction during expansion/shrinking * @nelems: Number of elements in table * @shift: Current size (1 << shift) * @p: Configuration parameters + * @run_work: Deferred worker to expand/shrink asynchronously + * @mutex: Mutex to protect current/future table swapping + * @being_destroyed: True if table is set up for destruction */ struct rhashtable { struct bucket_table __rcu *tbl; - size_t nelems; + struct bucket_table __rcu *future_tbl; + atomic_t nelems; size_t shift; struct rhashtable_params p; + struct delayed_work run_work; + struct mutex mutex; + bool being_destroyed; }; +static inline unsigned long rht_marker(const struct rhashtable *ht, u32 hash) +{ + return NULLS_MARKER(ht->p.nulls_base + hash); +} + +#define INIT_RHT_NULLS_HEAD(ptr, ht, hash) \ + ((ptr) = (typeof(ptr)) rht_marker(ht, hash)) + +static inline bool rht_is_a_nulls(const struct rhash_head *ptr) +{ + return ((unsigned long) ptr & 1); +} + +static inline unsigned long rht_get_nulls_value(const struct rhash_head *ptr) +{ + return ((unsigned long) ptr) >> 1; +} + #ifdef CONFIG_PROVE_LOCKING -int lockdep_rht_mutex_is_held(const struct rhashtable *ht); +int lockdep_rht_mutex_is_held(struct rhashtable *ht); +int lockdep_rht_bucket_is_held(const struct bucket_table *tbl, u32 hash); #else -static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht) +static inline int lockdep_rht_mutex_is_held(struct rhashtable *ht) +{ + return 1; +} + +static inline int lockdep_rht_bucket_is_held(const struct bucket_table *tbl, + u32 hash) { return 1; } @@ -96,13 +155,8 @@ static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht) int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params); -u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len); -u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr); - void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node); bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node); -void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj, - struct rhash_head __rcu **pprev); bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size); bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size); @@ -110,11 +164,11 @@ bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size); int rhashtable_expand(struct rhashtable *ht); int rhashtable_shrink(struct rhashtable *ht); -void *rhashtable_lookup(const struct rhashtable *ht, const void *key); -void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, +void *rhashtable_lookup(struct rhashtable *ht, const void *key); +void *rhashtable_lookup_compare(struct rhashtable *ht, const void *key, bool (*compare)(void *, void *), void *arg); -void rhashtable_destroy(const struct rhashtable *ht); +void rhashtable_destroy(struct rhashtable *ht); #define rht_dereference(p, ht) \ rcu_dereference_protected(p, lockdep_rht_mutex_is_held(ht)) @@ -122,92 +176,144 @@ void rhashtable_destroy(const struct rhashtable *ht); #define rht_dereference_rcu(p, ht) \ rcu_dereference_check(p, lockdep_rht_mutex_is_held(ht)) -#define rht_entry(ptr, type, member) container_of(ptr, type, member) -#define rht_entry_safe(ptr, type, member) \ -({ \ - typeof(ptr) __ptr = (ptr); \ - __ptr ? rht_entry(__ptr, type, member) : NULL; \ -}) +#define rht_dereference_bucket(p, tbl, hash) \ + rcu_dereference_protected(p, lockdep_rht_bucket_is_held(tbl, hash)) -#define rht_next_entry_safe(pos, ht, member) \ -({ \ - pos ? rht_entry_safe(rht_dereference((pos)->member.next, ht), \ - typeof(*(pos)), member) : NULL; \ -}) +#define rht_dereference_bucket_rcu(p, tbl, hash) \ + rcu_dereference_check(p, lockdep_rht_bucket_is_held(tbl, hash)) + +#define rht_entry(tpos, pos, member) \ + ({ tpos = container_of(pos, typeof(*tpos), member); 1; }) + +/** + * rht_for_each_continue - continue iterating over hash chain + * @pos: the &struct rhash_head to use as a loop cursor. + * @head: the previous &struct rhash_head to continue from + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + */ +#define rht_for_each_continue(pos, head, tbl, hash) \ + for (pos = rht_dereference_bucket(head, tbl, hash); \ + !rht_is_a_nulls(pos); \ + pos = rht_dereference_bucket((pos)->next, tbl, hash)) /** * rht_for_each - iterate over hash chain - * @pos: &struct rhash_head to use as a loop cursor. - * @head: head of the hash chain (struct rhash_head *) - * @ht: pointer to your struct rhashtable + * @pos: the &struct rhash_head to use as a loop cursor. + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index */ -#define rht_for_each(pos, head, ht) \ - for (pos = rht_dereference(head, ht); \ - pos; \ - pos = rht_dereference((pos)->next, ht)) +#define rht_for_each(pos, tbl, hash) \ + rht_for_each_continue(pos, (tbl)->buckets[hash], tbl, hash) + +/** + * rht_for_each_entry_continue - continue iterating over hash chain + * @tpos: the type * to use as a loop cursor. + * @pos: the &struct rhash_head to use as a loop cursor. + * @head: the previous &struct rhash_head to continue from + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * @member: name of the &struct rhash_head within the hashable struct. + */ +#define rht_for_each_entry_continue(tpos, pos, head, tbl, hash, member) \ + for (pos = rht_dereference_bucket(head, tbl, hash); \ + (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ + pos = rht_dereference_bucket((pos)->next, tbl, hash)) /** * rht_for_each_entry - iterate over hash chain of given type - * @pos: type * to use as a loop cursor. - * @head: head of the hash chain (struct rhash_head *) - * @ht: pointer to your struct rhashtable - * @member: name of the rhash_head within the hashable struct. + * @tpos: the type * to use as a loop cursor. + * @pos: the &struct rhash_head to use as a loop cursor. + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * @member: name of the &struct rhash_head within the hashable struct. */ -#define rht_for_each_entry(pos, head, ht, member) \ - for (pos = rht_entry_safe(rht_dereference(head, ht), \ - typeof(*(pos)), member); \ - pos; \ - pos = rht_next_entry_safe(pos, ht, member)) +#define rht_for_each_entry(tpos, pos, tbl, hash, member) \ + rht_for_each_entry_continue(tpos, pos, (tbl)->buckets[hash], \ + tbl, hash, member) /** * rht_for_each_entry_safe - safely iterate over hash chain of given type - * @pos: type * to use as a loop cursor. - * @n: type * to use for temporary next object storage - * @head: head of the hash chain (struct rhash_head *) - * @ht: pointer to your struct rhashtable - * @member: name of the rhash_head within the hashable struct. + * @tpos: the type * to use as a loop cursor. + * @pos: the &struct rhash_head to use as a loop cursor. + * @next: the &struct rhash_head to use as next in loop cursor. + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * @member: name of the &struct rhash_head within the hashable struct. * * This hash chain list-traversal primitive allows for the looped code to * remove the loop cursor from the list. */ -#define rht_for_each_entry_safe(pos, n, head, ht, member) \ - for (pos = rht_entry_safe(rht_dereference(head, ht), \ - typeof(*(pos)), member), \ - n = rht_next_entry_safe(pos, ht, member); \ - pos; \ - pos = n, \ - n = rht_next_entry_safe(pos, ht, member)) +#define rht_for_each_entry_safe(tpos, pos, next, tbl, hash, member) \ + for (pos = rht_dereference_bucket((tbl)->buckets[hash], tbl, hash), \ + next = !rht_is_a_nulls(pos) ? \ + rht_dereference_bucket(pos->next, tbl, hash) : NULL; \ + (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ + pos = next) + +/** + * rht_for_each_rcu_continue - continue iterating over rcu hash chain + * @pos: the &struct rhash_head to use as a loop cursor. + * @head: the previous &struct rhash_head to continue from + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * + * This hash chain list-traversal primitive may safely run concurrently with + * the _rcu mutation primitives such as rhashtable_insert() as long as the + * traversal is guarded by rcu_read_lock(). + */ +#define rht_for_each_rcu_continue(pos, head, tbl, hash) \ + for (({barrier(); }), \ + pos = rht_dereference_bucket_rcu(head, tbl, hash); \ + !rht_is_a_nulls(pos); \ + pos = rcu_dereference_raw(pos->next)) /** * rht_for_each_rcu - iterate over rcu hash chain - * @pos: &struct rhash_head to use as a loop cursor. - * @head: head of the hash chain (struct rhash_head *) - * @ht: pointer to your struct rhashtable + * @pos: the &struct rhash_head to use as a loop cursor. + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index * * This hash chain list-traversal primitive may safely run concurrently with - * the _rcu fkht mutation primitives such as rht_insert() as long as the + * the _rcu mutation primitives such as rhashtable_insert() as long as the * traversal is guarded by rcu_read_lock(). */ -#define rht_for_each_rcu(pos, head, ht) \ - for (pos = rht_dereference_rcu(head, ht); \ - pos; \ - pos = rht_dereference_rcu((pos)->next, ht)) +#define rht_for_each_rcu(pos, tbl, hash) \ + rht_for_each_rcu_continue(pos, (tbl)->buckets[hash], tbl, hash) + +/** + * rht_for_each_entry_rcu_continue - continue iterating over rcu hash chain + * @tpos: the type * to use as a loop cursor. + * @pos: the &struct rhash_head to use as a loop cursor. + * @head: the previous &struct rhash_head to continue from + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * @member: name of the &struct rhash_head within the hashable struct. + * + * This hash chain list-traversal primitive may safely run concurrently with + * the _rcu mutation primitives such as rhashtable_insert() as long as the + * traversal is guarded by rcu_read_lock(). + */ +#define rht_for_each_entry_rcu_continue(tpos, pos, head, tbl, hash, member) \ + for (({barrier(); }), \ + pos = rht_dereference_bucket_rcu(head, tbl, hash); \ + (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ + pos = rht_dereference_bucket_rcu(pos->next, tbl, hash)) /** * rht_for_each_entry_rcu - iterate over rcu hash chain of given type - * @pos: type * to use as a loop cursor. - * @head: head of the hash chain (struct rhash_head *) - * @member: name of the rhash_head within the hashable struct. + * @tpos: the type * to use as a loop cursor. + * @pos: the &struct rhash_head to use as a loop cursor. + * @tbl: the &struct bucket_table + * @hash: the hash value / bucket index + * @member: name of the &struct rhash_head within the hashable struct. * * This hash chain list-traversal primitive may safely run concurrently with - * the _rcu fkht mutation primitives such as rht_insert() as long as the + * the _rcu mutation primitives such as rhashtable_insert() as long as the * traversal is guarded by rcu_read_lock(). */ -#define rht_for_each_entry_rcu(pos, head, member) \ - for (pos = rht_entry_safe(rcu_dereference_raw(head), \ - typeof(*(pos)), member); \ - pos; \ - pos = rht_entry_safe(rcu_dereference_raw((pos)->member.next), \ - typeof(*(pos)), member)) +#define rht_for_each_entry_rcu(tpos, pos, tbl, hash, member) \ + rht_for_each_entry_rcu_continue(tpos, pos, (tbl)->buckets[hash],\ + tbl, hash, member) #endif /* _LINUX_RHASHTABLE_H */ diff --git a/include/linux/spinlock.h b/include/linux/spinlock.h index 262ba4ef9a8e..3e18379dfa6f 100644 --- a/include/linux/spinlock.h +++ b/include/linux/spinlock.h @@ -190,6 +190,8 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock) #ifdef CONFIG_DEBUG_LOCK_ALLOC # define raw_spin_lock_nested(lock, subclass) \ _raw_spin_lock_nested(lock, subclass) +# define raw_spin_lock_bh_nested(lock, subclass) \ + _raw_spin_lock_bh_nested(lock, subclass) # define raw_spin_lock_nest_lock(lock, nest_lock) \ do { \ @@ -205,6 +207,7 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock) # define raw_spin_lock_nested(lock, subclass) \ _raw_spin_lock(((void)(subclass), (lock))) # define raw_spin_lock_nest_lock(lock, nest_lock) _raw_spin_lock(lock) +# define raw_spin_lock_bh_nested(lock, subclass) _raw_spin_lock_bh(lock) #endif #if defined(CONFIG_SMP) || defined(CONFIG_DEBUG_SPINLOCK) @@ -324,6 +327,11 @@ do { \ raw_spin_lock_nested(spinlock_check(lock), subclass); \ } while (0) +#define spin_lock_bh_nested(lock, subclass) \ +do { \ + raw_spin_lock_bh_nested(spinlock_check(lock), subclass);\ +} while (0) + #define spin_lock_nest_lock(lock, nest_lock) \ do { \ raw_spin_lock_nest_lock(spinlock_check(lock), nest_lock); \ diff --git a/include/linux/spinlock_api_smp.h b/include/linux/spinlock_api_smp.h index 42dfab89e740..5344268e6e62 100644 --- a/include/linux/spinlock_api_smp.h +++ b/include/linux/spinlock_api_smp.h @@ -22,6 +22,8 @@ int in_lock_functions(unsigned long addr); void __lockfunc _raw_spin_lock(raw_spinlock_t *lock) __acquires(lock); void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass) __acquires(lock); +void __lockfunc _raw_spin_lock_bh_nested(raw_spinlock_t *lock, int subclass) + __acquires(lock); void __lockfunc _raw_spin_lock_nest_lock(raw_spinlock_t *lock, struct lockdep_map *map) __acquires(lock); diff --git a/include/linux/spinlock_api_up.h b/include/linux/spinlock_api_up.h index d0d188861ad6..d3afef9d8dbe 100644 --- a/include/linux/spinlock_api_up.h +++ b/include/linux/spinlock_api_up.h @@ -57,6 +57,7 @@ #define _raw_spin_lock(lock) __LOCK(lock) #define _raw_spin_lock_nested(lock, subclass) __LOCK(lock) +#define _raw_spin_lock_bh_nested(lock, subclass) __LOCK(lock) #define _raw_read_lock(lock) __LOCK(lock) #define _raw_write_lock(lock) __LOCK(lock) #define _raw_spin_lock_bh(lock) __LOCK_BH(lock) diff --git a/kernel/locking/spinlock.c b/kernel/locking/spinlock.c index 4b082b5cac9e..db3ccb1dd614 100644 --- a/kernel/locking/spinlock.c +++ b/kernel/locking/spinlock.c @@ -363,6 +363,14 @@ void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass) } EXPORT_SYMBOL(_raw_spin_lock_nested); +void __lockfunc _raw_spin_lock_bh_nested(raw_spinlock_t *lock, int subclass) +{ + __local_bh_disable_ip(_RET_IP_, SOFTIRQ_LOCK_OFFSET); + spin_acquire(&lock->dep_map, subclass, 0, _RET_IP_); + LOCK_CONTENDED(lock, do_raw_spin_trylock, do_raw_spin_lock); +} +EXPORT_SYMBOL(_raw_spin_lock_bh_nested); + unsigned long __lockfunc _raw_spin_lock_irqsave_nested(raw_spinlock_t *lock, int subclass) { diff --git a/lib/rhashtable.c b/lib/rhashtable.c index 6c3c723e902b..cbad192d3b3d 100644 --- a/lib/rhashtable.c +++ b/lib/rhashtable.c @@ -26,15 +26,47 @@ #define HASH_DEFAULT_SIZE 64UL #define HASH_MIN_SIZE 4UL +#define BUCKET_LOCKS_PER_CPU 128UL + +/* Base bits plus 1 bit for nulls marker */ +#define HASH_RESERVED_SPACE (RHT_BASE_BITS + 1) + +enum { + RHT_LOCK_NORMAL, + RHT_LOCK_NESTED, + RHT_LOCK_NESTED2, +}; + +/* The bucket lock is selected based on the hash and protects mutations + * on a group of hash buckets. + * + * IMPORTANT: When holding the bucket lock of both the old and new table + * during expansions and shrinking, the old bucket lock must always be + * acquired first. + */ +static spinlock_t *bucket_lock(const struct bucket_table *tbl, u32 hash) +{ + return &tbl->locks[hash & tbl->locks_mask]; +} #define ASSERT_RHT_MUTEX(HT) BUG_ON(!lockdep_rht_mutex_is_held(HT)) +#define ASSERT_BUCKET_LOCK(TBL, HASH) \ + BUG_ON(!lockdep_rht_bucket_is_held(TBL, HASH)) #ifdef CONFIG_PROVE_LOCKING -int lockdep_rht_mutex_is_held(const struct rhashtable *ht) +int lockdep_rht_mutex_is_held(struct rhashtable *ht) { - return ht->p.mutex_is_held(ht->p.parent); + return (debug_locks) ? lockdep_is_held(&ht->mutex) : 1; } EXPORT_SYMBOL_GPL(lockdep_rht_mutex_is_held); + +int lockdep_rht_bucket_is_held(const struct bucket_table *tbl, u32 hash) +{ + spinlock_t *lock = bucket_lock(tbl, hash); + + return (debug_locks) ? lockdep_is_held(lock) : 1; +} +EXPORT_SYMBOL_GPL(lockdep_rht_bucket_is_held); #endif static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he) @@ -42,75 +74,101 @@ static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he) return (void *) he - ht->p.head_offset; } -static u32 __hashfn(const struct rhashtable *ht, const void *key, - u32 len, u32 hsize) +static u32 rht_bucket_index(const struct bucket_table *tbl, u32 hash) { - u32 h; - - h = ht->p.hashfn(key, len, ht->p.hash_rnd); - - return h & (hsize - 1); + return hash & (tbl->size - 1); } -/** - * rhashtable_hashfn - compute hash for key of given length - * @ht: hash table to compute for - * @key: pointer to key - * @len: length of key - * - * Computes the hash value using the hash function provided in the 'hashfn' - * of struct rhashtable_params. The returned value is guaranteed to be - * smaller than the number of buckets in the hash table. - */ -u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len) +static u32 obj_raw_hashfn(const struct rhashtable *ht, const void *ptr) +{ + u32 hash; + + if (unlikely(!ht->p.key_len)) + hash = ht->p.obj_hashfn(ptr, ht->p.hash_rnd); + else + hash = ht->p.hashfn(ptr + ht->p.key_offset, ht->p.key_len, + ht->p.hash_rnd); + + return hash >> HASH_RESERVED_SPACE; +} + +static u32 key_hashfn(struct rhashtable *ht, const void *key, u32 len) { struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); + u32 hash; - return __hashfn(ht, key, len, tbl->size); + hash = ht->p.hashfn(key, len, ht->p.hash_rnd); + hash >>= HASH_RESERVED_SPACE; + + return rht_bucket_index(tbl, hash); } -EXPORT_SYMBOL_GPL(rhashtable_hashfn); - -static u32 obj_hashfn(const struct rhashtable *ht, const void *ptr, u32 hsize) -{ - if (unlikely(!ht->p.key_len)) { - u32 h; - - h = ht->p.obj_hashfn(ptr, ht->p.hash_rnd); - - return h & (hsize - 1); - } - - return __hashfn(ht, ptr + ht->p.key_offset, ht->p.key_len, hsize); -} - -/** - * rhashtable_obj_hashfn - compute hash for hashed object - * @ht: hash table to compute for - * @ptr: pointer to hashed object - * - * Computes the hash value using the hash function `hashfn` respectively - * 'obj_hashfn' depending on whether the hash table is set up to work with - * a fixed length key. The returned value is guaranteed to be smaller than - * the number of buckets in the hash table. - */ -u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr) -{ - struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); - - return obj_hashfn(ht, ptr, tbl->size); -} -EXPORT_SYMBOL_GPL(rhashtable_obj_hashfn); static u32 head_hashfn(const struct rhashtable *ht, - const struct rhash_head *he, u32 hsize) + const struct bucket_table *tbl, + const struct rhash_head *he) { - return obj_hashfn(ht, rht_obj(ht, he), hsize); + return rht_bucket_index(tbl, obj_raw_hashfn(ht, rht_obj(ht, he))); } -static struct bucket_table *bucket_table_alloc(size_t nbuckets) +static struct rhash_head __rcu **bucket_tail(struct bucket_table *tbl, u32 n) +{ + struct rhash_head __rcu **pprev; + + for (pprev = &tbl->buckets[n]; + !rht_is_a_nulls(rht_dereference_bucket(*pprev, tbl, n)); + pprev = &rht_dereference_bucket(*pprev, tbl, n)->next) + ; + + return pprev; +} + +static int alloc_bucket_locks(struct rhashtable *ht, struct bucket_table *tbl) +{ + unsigned int i, size; +#if defined(CONFIG_PROVE_LOCKING) + unsigned int nr_pcpus = 2; +#else + unsigned int nr_pcpus = num_possible_cpus(); +#endif + + nr_pcpus = min_t(unsigned int, nr_pcpus, 32UL); + size = roundup_pow_of_two(nr_pcpus * ht->p.locks_mul); + + /* Never allocate more than one lock per bucket */ + size = min_t(unsigned int, size, tbl->size); + + if (sizeof(spinlock_t) != 0) { +#ifdef CONFIG_NUMA + if (size * sizeof(spinlock_t) > PAGE_SIZE) + tbl->locks = vmalloc(size * sizeof(spinlock_t)); + else +#endif + tbl->locks = kmalloc_array(size, sizeof(spinlock_t), + GFP_KERNEL); + if (!tbl->locks) + return -ENOMEM; + for (i = 0; i < size; i++) + spin_lock_init(&tbl->locks[i]); + } + tbl->locks_mask = size - 1; + + return 0; +} + +static void bucket_table_free(const struct bucket_table *tbl) +{ + if (tbl) + kvfree(tbl->locks); + + kvfree(tbl); +} + +static struct bucket_table *bucket_table_alloc(struct rhashtable *ht, + size_t nbuckets) { struct bucket_table *tbl; size_t size; + int i; size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]); tbl = kzalloc(size, GFP_KERNEL | __GFP_NOWARN); @@ -122,12 +180,15 @@ static struct bucket_table *bucket_table_alloc(size_t nbuckets) tbl->size = nbuckets; - return tbl; -} + if (alloc_bucket_locks(ht, tbl) < 0) { + bucket_table_free(tbl); + return NULL; + } -static void bucket_table_free(const struct bucket_table *tbl) -{ - kvfree(tbl); + for (i = 0; i < nbuckets; i++) + INIT_RHT_NULLS_HEAD(tbl->buckets[i], ht, i); + + return tbl; } /** @@ -138,7 +199,7 @@ static void bucket_table_free(const struct bucket_table *tbl) bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size) { /* Expand table when exceeding 75% load */ - return ht->nelems > (new_size / 4 * 3); + return atomic_read(&ht->nelems) > (new_size / 4 * 3); } EXPORT_SYMBOL_GPL(rht_grow_above_75); @@ -150,41 +211,59 @@ EXPORT_SYMBOL_GPL(rht_grow_above_75); bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size) { /* Shrink table beneath 30% load */ - return ht->nelems < (new_size * 3 / 10); + return atomic_read(&ht->nelems) < (new_size * 3 / 10); } EXPORT_SYMBOL_GPL(rht_shrink_below_30); static void hashtable_chain_unzip(const struct rhashtable *ht, const struct bucket_table *new_tbl, - struct bucket_table *old_tbl, size_t n) + struct bucket_table *old_tbl, + size_t old_hash) { struct rhash_head *he, *p, *next; - unsigned int h; + spinlock_t *new_bucket_lock, *new_bucket_lock2 = NULL; + unsigned int new_hash, new_hash2; + + ASSERT_BUCKET_LOCK(old_tbl, old_hash); /* Old bucket empty, no work needed. */ - p = rht_dereference(old_tbl->buckets[n], ht); - if (!p) + p = rht_dereference_bucket(old_tbl->buckets[old_hash], old_tbl, + old_hash); + if (rht_is_a_nulls(p)) return; + new_hash = new_hash2 = head_hashfn(ht, new_tbl, p); + new_bucket_lock = bucket_lock(new_tbl, new_hash); + /* Advance the old bucket pointer one or more times until it * reaches a node that doesn't hash to the same bucket as the * previous node p. Call the previous node p; */ - h = head_hashfn(ht, p, new_tbl->size); - rht_for_each(he, p->next, ht) { - if (head_hashfn(ht, he, new_tbl->size) != h) + rht_for_each_continue(he, p->next, old_tbl, old_hash) { + new_hash2 = head_hashfn(ht, new_tbl, he); + if (new_hash != new_hash2) break; p = he; } - RCU_INIT_POINTER(old_tbl->buckets[n], p->next); + rcu_assign_pointer(old_tbl->buckets[old_hash], p->next); + + spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED); + + /* If we have encountered an entry that maps to a different bucket in + * the new table, lock down that bucket as well as we might cut off + * the end of the chain. + */ + new_bucket_lock2 = bucket_lock(new_tbl, new_hash); + if (new_bucket_lock != new_bucket_lock2) + spin_lock_bh_nested(new_bucket_lock2, RHT_LOCK_NESTED2); /* Find the subsequent node which does hash to the same * bucket as node P, or NULL if no such node exists. */ - next = NULL; - if (he) { - rht_for_each(he, he->next, ht) { - if (head_hashfn(ht, he, new_tbl->size) == h) { + INIT_RHT_NULLS_HEAD(next, ht, old_hash); + if (!rht_is_a_nulls(he)) { + rht_for_each_continue(he, he->next, old_tbl, old_hash) { + if (head_hashfn(ht, new_tbl, he) == new_hash) { next = he; break; } @@ -194,7 +273,23 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, /* Set p's next pointer to that subsequent node pointer, * bypassing the nodes which do not hash to p's bucket */ - RCU_INIT_POINTER(p->next, next); + rcu_assign_pointer(p->next, next); + + if (new_bucket_lock != new_bucket_lock2) + spin_unlock_bh(new_bucket_lock2); + spin_unlock_bh(new_bucket_lock); +} + +static void link_old_to_new(struct bucket_table *new_tbl, + unsigned int new_hash, struct rhash_head *entry) +{ + spinlock_t *new_bucket_lock; + + new_bucket_lock = bucket_lock(new_tbl, new_hash); + + spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED); + rcu_assign_pointer(*bucket_tail(new_tbl, new_hash), entry); + spin_unlock_bh(new_bucket_lock); } /** @@ -207,43 +302,59 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, * This function may only be called in a context where it is safe to call * synchronize_rcu(), e.g. not within a rcu_read_lock() section. * - * The caller must ensure that no concurrent table mutations take place. - * It is however valid to have concurrent lookups if they are RCU protected. + * The caller must ensure that no concurrent resizing occurs by holding + * ht->mutex. + * + * It is valid to have concurrent insertions and deletions protected by per + * bucket locks or concurrent RCU protected lookups and traversals. */ int rhashtable_expand(struct rhashtable *ht) { struct bucket_table *new_tbl, *old_tbl = rht_dereference(ht->tbl, ht); struct rhash_head *he; - unsigned int i, h; - bool complete; + spinlock_t *old_bucket_lock; + unsigned int new_hash, old_hash; + bool complete = false; ASSERT_RHT_MUTEX(ht); if (ht->p.max_shift && ht->shift >= ht->p.max_shift) return 0; - new_tbl = bucket_table_alloc(old_tbl->size * 2); + new_tbl = bucket_table_alloc(ht, old_tbl->size * 2); if (new_tbl == NULL) return -ENOMEM; ht->shift++; - /* For each new bucket, search the corresponding old bucket - * for the first entry that hashes to the new bucket, and - * link the new bucket to that entry. Since all the entries - * which will end up in the new bucket appear in the same - * old bucket, this constructs an entirely valid new hash - * table, but with multiple buckets "zipped" together into a - * single imprecise chain. + /* Make insertions go into the new, empty table right away. Deletions + * and lookups will be attempted in both tables until we synchronize. + * The synchronize_rcu() guarantees for the new table to be picked up + * so no new additions go into the old table while we relink. */ - for (i = 0; i < new_tbl->size; i++) { - h = i & (old_tbl->size - 1); - rht_for_each(he, old_tbl->buckets[h], ht) { - if (head_hashfn(ht, he, new_tbl->size) == i) { - RCU_INIT_POINTER(new_tbl->buckets[i], he); + rcu_assign_pointer(ht->future_tbl, new_tbl); + synchronize_rcu(); + + /* For each new bucket, search the corresponding old bucket for the + * first entry that hashes to the new bucket, and link the end of + * newly formed bucket chain (containing entries added to future + * table) to that entry. Since all the entries which will end up in + * the new bucket appear in the same old bucket, this constructs an + * entirely valid new hash table, but with multiple buckets + * "zipped" together into a single imprecise chain. + */ + for (new_hash = 0; new_hash < new_tbl->size; new_hash++) { + old_hash = rht_bucket_index(old_tbl, new_hash); + old_bucket_lock = bucket_lock(old_tbl, old_hash); + + spin_lock_bh(old_bucket_lock); + rht_for_each(he, old_tbl, old_hash) { + if (head_hashfn(ht, new_tbl, he) == new_hash) { + link_old_to_new(new_tbl, new_hash, he); break; } } + spin_unlock_bh(old_bucket_lock); } /* Publish the new table pointer. Lookups may now traverse @@ -253,7 +364,7 @@ int rhashtable_expand(struct rhashtable *ht) rcu_assign_pointer(ht->tbl, new_tbl); /* Unzip interleaved hash chains */ - do { + while (!complete && !ht->being_destroyed) { /* Wait for readers. All new readers will see the new * table, and thus no references to the old table will * remain. @@ -265,12 +376,21 @@ int rhashtable_expand(struct rhashtable *ht) * table): ... */ complete = true; - for (i = 0; i < old_tbl->size; i++) { - hashtable_chain_unzip(ht, new_tbl, old_tbl, i); - if (old_tbl->buckets[i] != NULL) + for (old_hash = 0; old_hash < old_tbl->size; old_hash++) { + struct rhash_head *head; + + old_bucket_lock = bucket_lock(old_tbl, old_hash); + spin_lock_bh(old_bucket_lock); + + hashtable_chain_unzip(ht, new_tbl, old_tbl, old_hash); + head = rht_dereference_bucket(old_tbl->buckets[old_hash], + old_tbl, old_hash); + if (!rht_is_a_nulls(head)) complete = false; + + spin_unlock_bh(old_bucket_lock); } - } while (!complete); + } bucket_table_free(old_tbl); return 0; @@ -284,45 +404,65 @@ EXPORT_SYMBOL_GPL(rhashtable_expand); * This function may only be called in a context where it is safe to call * synchronize_rcu(), e.g. not within a rcu_read_lock() section. * + * The caller must ensure that no concurrent resizing occurs by holding + * ht->mutex. + * * The caller must ensure that no concurrent table mutations take place. * It is however valid to have concurrent lookups if they are RCU protected. + * + * It is valid to have concurrent insertions and deletions protected by per + * bucket locks or concurrent RCU protected lookups and traversals. */ int rhashtable_shrink(struct rhashtable *ht) { - struct bucket_table *ntbl, *tbl = rht_dereference(ht->tbl, ht); - struct rhash_head __rcu **pprev; - unsigned int i; + struct bucket_table *new_tbl, *tbl = rht_dereference(ht->tbl, ht); + spinlock_t *new_bucket_lock, *old_bucket_lock1, *old_bucket_lock2; + unsigned int new_hash; ASSERT_RHT_MUTEX(ht); if (ht->shift <= ht->p.min_shift) return 0; - ntbl = bucket_table_alloc(tbl->size / 2); - if (ntbl == NULL) + new_tbl = bucket_table_alloc(ht, tbl->size / 2); + if (new_tbl == NULL) return -ENOMEM; - ht->shift--; + rcu_assign_pointer(ht->future_tbl, new_tbl); + synchronize_rcu(); - /* Link each bucket in the new table to the first bucket - * in the old table that contains entries which will hash - * to the new bucket. + /* Link the first entry in the old bucket to the end of the + * bucket in the new table. As entries are concurrently being + * added to the new table, lock down the new bucket. As we + * always divide the size in half when shrinking, each bucket + * in the new table maps to exactly two buckets in the old + * table. + * + * As removals can occur concurrently on the old table, we need + * to lock down both matching buckets in the old table. */ - for (i = 0; i < ntbl->size; i++) { - ntbl->buckets[i] = tbl->buckets[i]; + for (new_hash = 0; new_hash < new_tbl->size; new_hash++) { + old_bucket_lock1 = bucket_lock(tbl, new_hash); + old_bucket_lock2 = bucket_lock(tbl, new_hash + new_tbl->size); + new_bucket_lock = bucket_lock(new_tbl, new_hash); - /* Link each bucket in the new table to the first bucket - * in the old table that contains entries which will hash - * to the new bucket. - */ - for (pprev = &ntbl->buckets[i]; *pprev != NULL; - pprev = &rht_dereference(*pprev, ht)->next) - ; - RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]); + spin_lock_bh(old_bucket_lock1); + spin_lock_bh_nested(old_bucket_lock2, RHT_LOCK_NESTED); + spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED2); + + rcu_assign_pointer(*bucket_tail(new_tbl, new_hash), + tbl->buckets[new_hash]); + rcu_assign_pointer(*bucket_tail(new_tbl, new_hash), + tbl->buckets[new_hash + new_tbl->size]); + + spin_unlock_bh(new_bucket_lock); + spin_unlock_bh(old_bucket_lock2); + spin_unlock_bh(old_bucket_lock1); } /* Publish the new, valid hash table */ - rcu_assign_pointer(ht->tbl, ntbl); + rcu_assign_pointer(ht->tbl, new_tbl); + ht->shift--; /* Wait for readers. No new readers will have references to the * old hash table. @@ -335,60 +475,72 @@ int rhashtable_shrink(struct rhashtable *ht) } EXPORT_SYMBOL_GPL(rhashtable_shrink); +static void rht_deferred_worker(struct work_struct *work) +{ + struct rhashtable *ht; + struct bucket_table *tbl; + + ht = container_of(work, struct rhashtable, run_work.work); + mutex_lock(&ht->mutex); + tbl = rht_dereference(ht->tbl, ht); + + if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size)) + rhashtable_expand(ht); + else if (ht->p.shrink_decision && ht->p.shrink_decision(ht, tbl->size)) + rhashtable_shrink(ht); + + mutex_unlock(&ht->mutex); +} + /** * rhashtable_insert - insert object into hash hash table * @ht: hash table * @obj: pointer to hash head inside object * - * Will automatically grow the table via rhashtable_expand() if the the - * grow_decision function specified at rhashtable_init() returns true. + * Will take a per bucket spinlock to protect against mutual mutations + * on the same bucket. Multiple insertions may occur in parallel unless + * they map to the same bucket lock. * - * The caller must ensure that no concurrent table mutations occur. It is - * however valid to have concurrent lookups if they are RCU protected. + * It is safe to call this function from atomic context. + * + * Will trigger an automatic deferred table resizing if the size grows + * beyond the watermark indicated by grow_decision() which can be passed + * to rhashtable_init(). */ void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj) { - struct bucket_table *tbl = rht_dereference(ht->tbl, ht); - u32 hash; + struct bucket_table *tbl; + struct rhash_head *head; + spinlock_t *lock; + unsigned hash; - ASSERT_RHT_MUTEX(ht); + rcu_read_lock(); + + tbl = rht_dereference_rcu(ht->future_tbl, ht); + hash = head_hashfn(ht, tbl, obj); + lock = bucket_lock(tbl, hash); + + spin_lock_bh(lock); + head = rht_dereference_bucket(tbl->buckets[hash], tbl, hash); + if (rht_is_a_nulls(head)) + INIT_RHT_NULLS_HEAD(obj->next, ht, hash); + else + RCU_INIT_POINTER(obj->next, head); - hash = head_hashfn(ht, obj, tbl->size); - RCU_INIT_POINTER(obj->next, tbl->buckets[hash]); rcu_assign_pointer(tbl->buckets[hash], obj); - ht->nelems++; + spin_unlock_bh(lock); - if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size)) - rhashtable_expand(ht); + atomic_inc(&ht->nelems); + + /* Only grow the table if no resizing is currently in progress. */ + if (ht->tbl != ht->future_tbl && + ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size)) + schedule_delayed_work(&ht->run_work, 0); + + rcu_read_unlock(); } EXPORT_SYMBOL_GPL(rhashtable_insert); -/** - * rhashtable_remove_pprev - remove object from hash table given previous element - * @ht: hash table - * @obj: pointer to hash head inside object - * @pprev: pointer to previous element - * - * Identical to rhashtable_remove() but caller is alreayd aware of the element - * in front of the element to be deleted. This is in particular useful for - * deletion when combined with walking or lookup. - */ -void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj, - struct rhash_head __rcu **pprev) -{ - struct bucket_table *tbl = rht_dereference(ht->tbl, ht); - - ASSERT_RHT_MUTEX(ht); - - RCU_INIT_POINTER(*pprev, obj->next); - ht->nelems--; - - if (ht->p.shrink_decision && - ht->p.shrink_decision(ht, tbl->size)) - rhashtable_shrink(ht); -} -EXPORT_SYMBOL_GPL(rhashtable_remove_pprev); - /** * rhashtable_remove - remove object from hash table * @ht: hash table @@ -406,26 +558,56 @@ EXPORT_SYMBOL_GPL(rhashtable_remove_pprev); */ bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj) { - struct bucket_table *tbl = rht_dereference(ht->tbl, ht); + struct bucket_table *tbl; struct rhash_head __rcu **pprev; struct rhash_head *he; - u32 h; + spinlock_t *lock; + unsigned int hash; - ASSERT_RHT_MUTEX(ht); + rcu_read_lock(); + tbl = rht_dereference_rcu(ht->tbl, ht); + hash = head_hashfn(ht, tbl, obj); - h = head_hashfn(ht, obj, tbl->size); + lock = bucket_lock(tbl, hash); + spin_lock_bh(lock); - pprev = &tbl->buckets[h]; - rht_for_each(he, tbl->buckets[h], ht) { +restart: + pprev = &tbl->buckets[hash]; + rht_for_each(he, tbl, hash) { if (he != obj) { pprev = &he->next; continue; } - rhashtable_remove_pprev(ht, he, pprev); + rcu_assign_pointer(*pprev, obj->next); + atomic_dec(&ht->nelems); + + spin_unlock_bh(lock); + + if (ht->tbl != ht->future_tbl && + ht->p.shrink_decision && + ht->p.shrink_decision(ht, tbl->size)) + schedule_delayed_work(&ht->run_work, 0); + + rcu_read_unlock(); + return true; } + if (tbl != rht_dereference_rcu(ht->tbl, ht)) { + spin_unlock_bh(lock); + + tbl = rht_dereference_rcu(ht->tbl, ht); + hash = head_hashfn(ht, tbl, obj); + + lock = bucket_lock(tbl, hash); + spin_lock_bh(lock); + goto restart; + } + + spin_unlock_bh(lock); + rcu_read_unlock(); + return false; } EXPORT_SYMBOL_GPL(rhashtable_remove); @@ -441,25 +623,35 @@ EXPORT_SYMBOL_GPL(rhashtable_remove); * This lookup function may only be used for fixed key hash table (key_len * paramter set). It will BUG() if used inappropriately. * - * Lookups may occur in parallel with hash mutations as long as the lookup is - * guarded by rcu_read_lock(). The caller must take care of this. + * Lookups may occur in parallel with hashtable mutations and resizing. */ -void *rhashtable_lookup(const struct rhashtable *ht, const void *key) +void *rhashtable_lookup(struct rhashtable *ht, const void *key) { - const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); + const struct bucket_table *tbl, *old_tbl; struct rhash_head *he; - u32 h; + u32 hash; BUG_ON(!ht->p.key_len); - h = __hashfn(ht, key, ht->p.key_len, tbl->size); - rht_for_each_rcu(he, tbl->buckets[h], ht) { + rcu_read_lock(); + old_tbl = rht_dereference_rcu(ht->tbl, ht); + tbl = rht_dereference_rcu(ht->future_tbl, ht); + hash = key_hashfn(ht, key, ht->p.key_len); +restart: + rht_for_each_rcu(he, tbl, rht_bucket_index(tbl, hash)) { if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key, ht->p.key_len)) continue; - return (void *) he - ht->p.head_offset; + rcu_read_unlock(); + return rht_obj(ht, he); } + if (unlikely(tbl != old_tbl)) { + tbl = old_tbl; + goto restart; + } + + rcu_read_unlock(); return NULL; } EXPORT_SYMBOL_GPL(rhashtable_lookup); @@ -467,33 +659,43 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup); /** * rhashtable_lookup_compare - search hash table with compare function * @ht: hash table - * @hash: hash value of desired entry + * @key: the pointer to the key * @compare: compare function, must return true on match * @arg: argument passed on to compare function * * Traverses the bucket chain behind the provided hash value and calls the * specified compare function for each entry. * - * Lookups may occur in parallel with hash mutations as long as the lookup is - * guarded by rcu_read_lock(). The caller must take care of this. + * Lookups may occur in parallel with hashtable mutations and resizing. * * Returns the first entry on which the compare function returned true. */ -void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, +void *rhashtable_lookup_compare(struct rhashtable *ht, const void *key, bool (*compare)(void *, void *), void *arg) { - const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); + const struct bucket_table *tbl, *old_tbl; struct rhash_head *he; + u32 hash; - if (unlikely(hash >= tbl->size)) - return NULL; + rcu_read_lock(); - rht_for_each_rcu(he, tbl->buckets[hash], ht) { + old_tbl = rht_dereference_rcu(ht->tbl, ht); + tbl = rht_dereference_rcu(ht->future_tbl, ht); + hash = key_hashfn(ht, key, ht->p.key_len); +restart: + rht_for_each_rcu(he, tbl, rht_bucket_index(tbl, hash)) { if (!compare(rht_obj(ht, he), arg)) continue; - return (void *) he - ht->p.head_offset; + rcu_read_unlock(); + return rht_obj(ht, he); } + if (unlikely(tbl != old_tbl)) { + tbl = old_tbl; + goto restart; + } + rcu_read_unlock(); + return NULL; } EXPORT_SYMBOL_GPL(rhashtable_lookup_compare); @@ -525,9 +727,7 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params) * .key_offset = offsetof(struct test_obj, key), * .key_len = sizeof(int), * .hashfn = jhash, - * #ifdef CONFIG_PROVE_LOCKING - * .mutex_is_held = &my_mutex_is_held, - * #endif + * .nulls_base = (1U << RHT_BASE_SHIFT), * }; * * Configuration Example 2: Variable length keys @@ -547,9 +747,6 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params) * .head_offset = offsetof(struct test_obj, node), * .hashfn = jhash, * .obj_hashfn = my_hash_fn, - * #ifdef CONFIG_PROVE_LOCKING - * .mutex_is_held = &my_mutex_is_held, - * #endif * }; */ int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params) @@ -563,24 +760,38 @@ int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params) (!params->key_len && !params->obj_hashfn)) return -EINVAL; + if (params->nulls_base && params->nulls_base < (1U << RHT_BASE_SHIFT)) + return -EINVAL; + params->min_shift = max_t(size_t, params->min_shift, ilog2(HASH_MIN_SIZE)); if (params->nelem_hint) size = rounded_hashtable_size(params); - tbl = bucket_table_alloc(size); + memset(ht, 0, sizeof(*ht)); + mutex_init(&ht->mutex); + memcpy(&ht->p, params, sizeof(*params)); + + if (params->locks_mul) + ht->p.locks_mul = roundup_pow_of_two(params->locks_mul); + else + ht->p.locks_mul = BUCKET_LOCKS_PER_CPU; + + tbl = bucket_table_alloc(ht, size); if (tbl == NULL) return -ENOMEM; - memset(ht, 0, sizeof(*ht)); ht->shift = ilog2(tbl->size); - memcpy(&ht->p, params, sizeof(*params)); RCU_INIT_POINTER(ht->tbl, tbl); + RCU_INIT_POINTER(ht->future_tbl, tbl); if (!ht->p.hash_rnd) get_random_bytes(&ht->p.hash_rnd, sizeof(ht->p.hash_rnd)); + if (ht->p.grow_decision || ht->p.shrink_decision) + INIT_DEFERRABLE_WORK(&ht->run_work, rht_deferred_worker); + return 0; } EXPORT_SYMBOL_GPL(rhashtable_init); @@ -593,9 +804,16 @@ EXPORT_SYMBOL_GPL(rhashtable_init); * has to make sure that no resizing may happen by unpublishing the hashtable * and waiting for the quiescent cycle before releasing the bucket array. */ -void rhashtable_destroy(const struct rhashtable *ht) +void rhashtable_destroy(struct rhashtable *ht) { - bucket_table_free(ht->tbl); + ht->being_destroyed = true; + + mutex_lock(&ht->mutex); + + cancel_delayed_work(&ht->run_work); + bucket_table_free(rht_dereference(ht->tbl, ht)); + + mutex_unlock(&ht->mutex); } EXPORT_SYMBOL_GPL(rhashtable_destroy); @@ -610,13 +828,6 @@ EXPORT_SYMBOL_GPL(rhashtable_destroy); #define TEST_PTR ((void *) 0xdeadbeef) #define TEST_NEXPANDS 4 -#ifdef CONFIG_PROVE_LOCKING -static int test_mutex_is_held(void *parent) -{ - return 1; -} -#endif - struct test_obj { void *ptr; int value; @@ -656,6 +867,7 @@ static int __init test_rht_lookup(struct rhashtable *ht) static void test_bucket_stats(struct rhashtable *ht, bool quiet) { unsigned int cnt, rcu_cnt, i, total = 0; + struct rhash_head *pos; struct test_obj *obj; struct bucket_table *tbl; @@ -666,14 +878,14 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet) if (!quiet) pr_info(" [%#4x/%zu]", i, tbl->size); - rht_for_each_entry_rcu(obj, tbl->buckets[i], node) { + rht_for_each_entry_rcu(obj, pos, tbl, i, node) { cnt++; total++; if (!quiet) pr_cont(" [%p],", obj); } - rht_for_each_entry_rcu(obj, tbl->buckets[i], node) + rht_for_each_entry_rcu(obj, pos, tbl, i, node) rcu_cnt++; if (rcu_cnt != cnt) @@ -685,17 +897,18 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet) i, tbl->buckets[i], cnt); } - pr_info(" Traversal complete: counted=%u, nelems=%zu, entries=%d\n", - total, ht->nelems, TEST_ENTRIES); + pr_info(" Traversal complete: counted=%u, nelems=%u, entries=%d\n", + total, atomic_read(&ht->nelems), TEST_ENTRIES); - if (total != ht->nelems || total != TEST_ENTRIES) + if (total != atomic_read(&ht->nelems) || total != TEST_ENTRIES) pr_warn("Test failed: Total count mismatch ^^^"); } static int __init test_rhashtable(struct rhashtable *ht) { struct bucket_table *tbl; - struct test_obj *obj, *next; + struct test_obj *obj; + struct rhash_head *pos, *next; int err; unsigned int i; @@ -726,7 +939,9 @@ static int __init test_rhashtable(struct rhashtable *ht) for (i = 0; i < TEST_NEXPANDS; i++) { pr_info(" Table expansion iteration %u...\n", i); + mutex_lock(&ht->mutex); rhashtable_expand(ht); + mutex_unlock(&ht->mutex); rcu_read_lock(); pr_info(" Verifying lookups...\n"); @@ -736,7 +951,9 @@ static int __init test_rhashtable(struct rhashtable *ht) for (i = 0; i < TEST_NEXPANDS; i++) { pr_info(" Table shrinkage iteration %u...\n", i); + mutex_lock(&ht->mutex); rhashtable_shrink(ht); + mutex_unlock(&ht->mutex); rcu_read_lock(); pr_info(" Verifying lookups...\n"); @@ -764,7 +981,7 @@ static int __init test_rhashtable(struct rhashtable *ht) error: tbl = rht_dereference_rcu(ht->tbl, ht); for (i = 0; i < tbl->size; i++) - rht_for_each_entry_safe(obj, next, tbl->buckets[i], ht, node) + rht_for_each_entry_safe(obj, pos, next, tbl, i, node) kfree(obj); return err; @@ -779,9 +996,7 @@ static int __init test_rht_init(void) .key_offset = offsetof(struct test_obj, value), .key_len = sizeof(int), .hashfn = jhash, -#ifdef CONFIG_PROVE_LOCKING - .mutex_is_held = &test_mutex_is_held, -#endif + .nulls_base = (3U << RHT_BASE_SHIFT), .grow_decision = rht_grow_above_75, .shrink_decision = rht_shrink_below_30, }; diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c index 1e316ce4cb5d..75887d7d2c6a 100644 --- a/net/netfilter/nft_hash.c +++ b/net/netfilter/nft_hash.c @@ -33,7 +33,7 @@ static bool nft_hash_lookup(const struct nft_set *set, const struct nft_data *key, struct nft_data *data) { - const struct rhashtable *priv = nft_set_priv(set); + struct rhashtable *priv = nft_set_priv(set); const struct nft_hash_elem *he; he = rhashtable_lookup(priv, key); @@ -83,46 +83,53 @@ static void nft_hash_remove(const struct nft_set *set, const struct nft_set_elem *elem) { struct rhashtable *priv = nft_set_priv(set); - struct rhash_head *he, __rcu **pprev; - - pprev = elem->cookie; - he = rht_dereference((*pprev), priv); - - rhashtable_remove_pprev(priv, he, pprev); + rhashtable_remove(priv, elem->cookie); synchronize_rcu(); - kfree(he); + kfree(elem->cookie); +} + +struct nft_compare_arg { + const struct nft_set *set; + struct nft_set_elem *elem; +}; + +static bool nft_hash_compare(void *ptr, void *arg) +{ + struct nft_hash_elem *he = ptr; + struct nft_compare_arg *x = arg; + + if (!nft_data_cmp(&he->key, &x->elem->key, x->set->klen)) { + x->elem->cookie = he; + x->elem->flags = 0; + if (x->set->flags & NFT_SET_MAP) + nft_data_copy(&x->elem->data, he->data); + + return true; + } + + return false; } static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem) { - const struct rhashtable *priv = nft_set_priv(set); - const struct bucket_table *tbl = rht_dereference_rcu(priv->tbl, priv); - struct rhash_head __rcu * const *pprev; - struct nft_hash_elem *he; - u32 h; + struct rhashtable *priv = nft_set_priv(set); + struct nft_compare_arg arg = { + .set = set, + .elem = elem, + }; - h = rhashtable_hashfn(priv, &elem->key, set->klen); - pprev = &tbl->buckets[h]; - rht_for_each_entry_rcu(he, tbl->buckets[h], node) { - if (nft_data_cmp(&he->key, &elem->key, set->klen)) { - pprev = &he->node.next; - continue; - } - - elem->cookie = (void *)pprev; - elem->flags = 0; - if (set->flags & NFT_SET_MAP) - nft_data_copy(&elem->data, he->data); + if (rhashtable_lookup_compare(priv, &elem->key, + &nft_hash_compare, &arg)) return 0; - } + return -ENOENT; } static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, struct nft_set_iter *iter) { - const struct rhashtable *priv = nft_set_priv(set); + struct rhashtable *priv = nft_set_priv(set); const struct bucket_table *tbl; const struct nft_hash_elem *he; struct nft_set_elem elem; @@ -130,7 +137,9 @@ static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, tbl = rht_dereference_rcu(priv->tbl, priv); for (i = 0; i < tbl->size; i++) { - rht_for_each_entry_rcu(he, tbl->buckets[i], node) { + struct rhash_head *pos; + + rht_for_each_entry_rcu(he, pos, tbl, i, node) { if (iter->count < iter->skip) goto cont; @@ -153,13 +162,6 @@ static unsigned int nft_hash_privsize(const struct nlattr * const nla[]) return sizeof(struct rhashtable); } -#ifdef CONFIG_PROVE_LOCKING -static int lockdep_nfnl_lock_is_held(void *parent) -{ - return lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES); -} -#endif - static int nft_hash_init(const struct nft_set *set, const struct nft_set_desc *desc, const struct nlattr * const tb[]) @@ -173,9 +175,6 @@ static int nft_hash_init(const struct nft_set *set, .hashfn = jhash, .grow_decision = rht_grow_above_75, .shrink_decision = rht_shrink_below_30, -#ifdef CONFIG_PROVE_LOCKING - .mutex_is_held = lockdep_nfnl_lock_is_held, -#endif }; return rhashtable_init(priv, ¶ms); @@ -183,18 +182,23 @@ static int nft_hash_init(const struct nft_set *set, static void nft_hash_destroy(const struct nft_set *set) { - const struct rhashtable *priv = nft_set_priv(set); - const struct bucket_table *tbl = priv->tbl; - struct nft_hash_elem *he, *next; + struct rhashtable *priv = nft_set_priv(set); + const struct bucket_table *tbl; + struct nft_hash_elem *he; + struct rhash_head *pos, *next; unsigned int i; + /* Stop an eventual async resizing */ + priv->being_destroyed = true; + mutex_lock(&priv->mutex); + + tbl = rht_dereference(priv->tbl, priv); for (i = 0; i < tbl->size; i++) { - for (he = rht_entry(tbl->buckets[i], struct nft_hash_elem, node); - he != NULL; he = next) { - next = rht_entry(he->node.next, struct nft_hash_elem, node); + rht_for_each_entry_safe(he, pos, next, tbl, i, node) nft_hash_elem_destroy(set, he); - } } + mutex_unlock(&priv->mutex); + rhashtable_destroy(priv); } diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 84ea76ca3f1f..298e1df7132a 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -97,12 +97,12 @@ static int netlink_dump(struct sock *sk); static void netlink_skb_destructor(struct sk_buff *skb); /* nl_table locking explained: - * Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock - * combined with an RCU read-side lock. Insertion and removal are protected - * with nl_sk_hash_lock while using RCU list modification primitives and may - * run in parallel to nl_table_lock protected lookups. Destruction of the - * Netlink socket may only occur *after* nl_table_lock has been acquired - * either during or after the socket has been removed from the list. + * Lookup and traversal are protected with an RCU read-side lock. Insertion + * and removal are protected with nl_sk_hash_lock while using RCU list + * modification primitives and may run in parallel to RCU protected lookups. + * Destruction of the Netlink socket may only occur *after* nl_table_lock has + * been acquired * either during or after the socket has been removed from + * the list and after an RCU grace period. */ DEFINE_RWLOCK(nl_table_lock); EXPORT_SYMBOL_GPL(nl_table_lock); @@ -114,15 +114,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); DEFINE_MUTEX(nl_sk_hash_lock); EXPORT_SYMBOL_GPL(nl_sk_hash_lock); -#ifdef CONFIG_PROVE_LOCKING -static int lockdep_nl_sk_hash_is_held(void *parent) -{ - if (debug_locks) - return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock); - return 1; -} -#endif - static ATOMIC_NOTIFIER_HEAD(netlink_chain); static DEFINE_SPINLOCK(netlink_tap_lock); @@ -1002,11 +993,8 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, .net = net, .portid = portid, }; - u32 hash; - hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid)); - - return rhashtable_lookup_compare(&table->hash, hash, + return rhashtable_lookup_compare(&table->hash, &portid, &netlink_compare, &arg); } @@ -1015,13 +1003,11 @@ static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) struct netlink_table *table = &nl_table[protocol]; struct sock *sk; - read_lock(&nl_table_lock); rcu_read_lock(); sk = __netlink_lookup(table, portid, net); if (sk) sock_hold(sk); rcu_read_unlock(); - read_unlock(&nl_table_lock); return sk; } @@ -1066,7 +1052,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) goto err; err = -ENOMEM; - if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX)) + if (BITS_PER_LONG > 32 && + unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX)) goto err; nlk_sk(sk)->portid = portid; @@ -1194,6 +1181,13 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, goto out; } +static void deferred_put_nlk_sk(struct rcu_head *head) +{ + struct netlink_sock *nlk = container_of(head, struct netlink_sock, rcu); + + sock_put(&nlk->sk); +} + static int netlink_release(struct socket *sock) { struct sock *sk = sock->sk; @@ -1259,7 +1253,7 @@ static int netlink_release(struct socket *sock) local_bh_disable(); sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); local_bh_enable(); - sock_put(sk); + call_rcu(&nlk->rcu, deferred_put_nlk_sk); return 0; } @@ -1274,7 +1268,6 @@ static int netlink_autobind(struct socket *sock) retry: cond_resched(); - netlink_table_grab(); rcu_read_lock(); if (__netlink_lookup(table, portid, net)) { /* Bind collision, search negative portid values. */ @@ -1282,11 +1275,9 @@ static int netlink_autobind(struct socket *sock) if (rover > -4097) rover = -4097; rcu_read_unlock(); - netlink_table_ungrab(); goto retry; } rcu_read_unlock(); - netlink_table_ungrab(); err = netlink_insert(sk, net, portid); if (err == -EADDRINUSE) @@ -2901,7 +2892,9 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); for (j = 0; j < tbl->size; j++) { - rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { + struct rhash_head *node; + + rht_for_each_entry_rcu(nlk, node, tbl, j, node) { s = (struct sock *)nlk; if (sock_net(s) != seq_file_net(seq)) @@ -2919,9 +2912,8 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) } static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(nl_table_lock) __acquires(RCU) + __acquires(RCU) { - read_lock(&nl_table_lock); rcu_read_lock(); return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; } @@ -2929,6 +2921,8 @@ static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) { struct rhashtable *ht; + const struct bucket_table *tbl; + struct rhash_head *node; struct netlink_sock *nlk; struct nl_seq_iter *iter; struct net *net; @@ -2945,17 +2939,17 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) i = iter->link; ht = &nl_table[i].hash; - rht_for_each_entry(nlk, nlk->node.next, ht, node) + tbl = rht_dereference_rcu(ht->tbl, ht); + rht_for_each_entry_rcu_continue(nlk, node, nlk->node.next, tbl, iter->hash_idx, node) if (net_eq(sock_net((struct sock *)nlk), net)) return nlk; j = iter->hash_idx + 1; do { - const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); for (; j < tbl->size; j++) { - rht_for_each_entry(nlk, tbl->buckets[j], ht, node) { + rht_for_each_entry_rcu(nlk, node, tbl, j, node) { if (net_eq(sock_net((struct sock *)nlk), net)) { iter->link = i; iter->hash_idx = j; @@ -2971,10 +2965,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static void netlink_seq_stop(struct seq_file *seq, void *v) - __releases(RCU) __releases(nl_table_lock) + __releases(RCU) { rcu_read_unlock(); - read_unlock(&nl_table_lock); } @@ -3121,9 +3114,6 @@ static int __init netlink_proto_init(void) .max_shift = 16, /* 64K */ .grow_decision = rht_grow_above_75, .shrink_decision = rht_shrink_below_30, -#ifdef CONFIG_PROVE_LOCKING - .mutex_is_held = lockdep_nl_sk_hash_is_held, -#endif }; if (err != 0) diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h index f123a88496f8..fd96fa76202a 100644 --- a/net/netlink/af_netlink.h +++ b/net/netlink/af_netlink.h @@ -50,6 +50,7 @@ struct netlink_sock { #endif /* CONFIG_NETLINK_MMAP */ struct rhash_head node; + struct rcu_head rcu; }; static inline struct netlink_sock *nlk_sk(struct sock *sk) diff --git a/net/netlink/diag.c b/net/netlink/diag.c index de8c74a3c061..fcca36d81a62 100644 --- a/net/netlink/diag.c +++ b/net/netlink/diag.c @@ -113,7 +113,9 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, req = nlmsg_data(cb->nlh); for (i = 0; i < htbl->size; i++) { - rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) { + struct rhash_head *pos; + + rht_for_each_entry(nlsk, pos, htbl, i, node) { sk = (struct sock *)nlsk; if (!net_eq(sock_net(sk), net))