diff --git a/drivers/infiniband/ulp/rtrs/rtrs-srv.c b/drivers/infiniband/ulp/rtrs/rtrs-srv.c index d6f93601712e..1cb778aff3c5 100644 --- a/drivers/infiniband/ulp/rtrs/rtrs-srv.c +++ b/drivers/infiniband/ulp/rtrs/rtrs-srv.c @@ -1328,51 +1328,6 @@ static void rtrs_srv_dev_release(struct device *dev) kfree(srv); } -static struct rtrs_srv *__alloc_srv(struct rtrs_srv_ctx *ctx, - const uuid_t *paths_uuid) -{ - struct rtrs_srv *srv; - int i; - - srv = kzalloc(sizeof(*srv), GFP_KERNEL); - if (!srv) - return NULL; - - refcount_set(&srv->refcount, 1); - INIT_LIST_HEAD(&srv->paths_list); - mutex_init(&srv->paths_mutex); - mutex_init(&srv->paths_ev_mutex); - uuid_copy(&srv->paths_uuid, paths_uuid); - srv->queue_depth = sess_queue_depth; - srv->ctx = ctx; - device_initialize(&srv->dev); - srv->dev.release = rtrs_srv_dev_release; - - srv->chunks = kcalloc(srv->queue_depth, sizeof(*srv->chunks), - GFP_KERNEL); - if (!srv->chunks) - goto err_free_srv; - - for (i = 0; i < srv->queue_depth; i++) { - srv->chunks[i] = mempool_alloc(chunk_pool, GFP_KERNEL); - if (!srv->chunks[i]) - goto err_free_chunks; - } - list_add(&srv->ctx_list, &ctx->srv_list); - - return srv; - -err_free_chunks: - while (i--) - mempool_free(srv->chunks[i], chunk_pool); - kfree(srv->chunks); - -err_free_srv: - kfree(srv); - - return NULL; -} - static void free_srv(struct rtrs_srv *srv) { int i; @@ -1387,32 +1342,61 @@ static void free_srv(struct rtrs_srv *srv) put_device(&srv->dev); } -static inline struct rtrs_srv *__find_srv_and_get(struct rtrs_srv_ctx *ctx, - const uuid_t *paths_uuid) -{ - struct rtrs_srv *srv; - - list_for_each_entry(srv, &ctx->srv_list, ctx_list) { - if (uuid_equal(&srv->paths_uuid, paths_uuid) && - refcount_inc_not_zero(&srv->refcount)) - return srv; - } - - return NULL; -} - static struct rtrs_srv *get_or_create_srv(struct rtrs_srv_ctx *ctx, const uuid_t *paths_uuid) { struct rtrs_srv *srv; + int i; mutex_lock(&ctx->srv_mutex); - srv = __find_srv_and_get(ctx, paths_uuid); - if (!srv) - srv = __alloc_srv(ctx, paths_uuid); + list_for_each_entry(srv, &ctx->srv_list, ctx_list) { + if (uuid_equal(&srv->paths_uuid, paths_uuid) && + refcount_inc_not_zero(&srv->refcount)) { + mutex_unlock(&ctx->srv_mutex); + return srv; + } + } + + /* need to allocate a new srv */ + srv = kzalloc(sizeof(*srv), GFP_KERNEL); + if (!srv) { + mutex_unlock(&ctx->srv_mutex); + return NULL; + } + + INIT_LIST_HEAD(&srv->paths_list); + mutex_init(&srv->paths_mutex); + mutex_init(&srv->paths_ev_mutex); + uuid_copy(&srv->paths_uuid, paths_uuid); + srv->queue_depth = sess_queue_depth; + srv->ctx = ctx; + device_initialize(&srv->dev); + srv->dev.release = rtrs_srv_dev_release; + list_add(&srv->ctx_list, &ctx->srv_list); mutex_unlock(&ctx->srv_mutex); + srv->chunks = kcalloc(srv->queue_depth, sizeof(*srv->chunks), + GFP_KERNEL); + if (!srv->chunks) + goto err_free_srv; + + for (i = 0; i < srv->queue_depth; i++) { + srv->chunks[i] = mempool_alloc(chunk_pool, GFP_KERNEL); + if (!srv->chunks[i]) + goto err_free_chunks; + } + refcount_set(&srv->refcount, 1); + return srv; + +err_free_chunks: + while (i--) + mempool_free(srv->chunks[i], chunk_pool); + kfree(srv->chunks); + +err_free_srv: + kfree(srv); + return NULL; } static void put_srv(struct rtrs_srv *srv) @@ -1813,7 +1797,11 @@ static int rtrs_rdma_connect(struct rdma_cm_id *cm_id, } recon_cnt = le16_to_cpu(msg->recon_cnt); srv = get_or_create_srv(ctx, &msg->paths_uuid); - if (!srv) { + /* + * "refcount == 0" happens if a previous thread calls get_or_create_srv + * allocate srv, but chunks of srv are not allocated yet. + */ + if (!srv || refcount_read(&srv->refcount) == 0) { err = -ENOMEM; goto reject_w_err; }