diff --git a/mm/madvise.c b/mm/madvise.c index 08b207f8e61e..fa5dae5a7723 100644 --- a/mm/madvise.c +++ b/mm/madvise.c @@ -1574,6 +1574,50 @@ int madvise_set_anon_name(struct mm_struct *mm, unsigned long start, madvise_vma_anon_name); } #endif /* CONFIG_ANON_VMA_NAME */ + +#ifdef CONFIG_MEMORY_FAILURE +static bool is_memory_failure(int behavior) +{ + switch (behavior) { + case MADV_HWPOISON: + case MADV_SOFT_OFFLINE: + return true; + default: + return false; + } +} +#else +static bool is_memory_failure(int behavior) +{ + return false; +} +#endif + +static int madvise_lock(struct mm_struct *mm, int behavior) +{ + if (is_memory_failure(behavior)) + return 0; + + if (madvise_need_mmap_write(behavior)) { + if (mmap_write_lock_killable(mm)) + return -EINTR; + } else { + mmap_read_lock(mm); + } + return 0; +} + +static void madvise_unlock(struct mm_struct *mm, int behavior) +{ + if (is_memory_failure(behavior)) + return; + + if (madvise_need_mmap_write(behavior)) + mmap_write_unlock(mm); + else + mmap_read_unlock(mm); +} + /* * The madvise(2) system call. * @@ -1650,7 +1694,6 @@ int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh { unsigned long end; int error; - int write; size_t len; struct blk_plug plug; @@ -1672,19 +1715,15 @@ int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh if (end == start) return 0; + error = madvise_lock(mm, behavior); + if (error) + return error; + #ifdef CONFIG_MEMORY_FAILURE if (behavior == MADV_HWPOISON || behavior == MADV_SOFT_OFFLINE) return madvise_inject_error(behavior, start, start + len_in); #endif - write = madvise_need_mmap_write(behavior); - if (write) { - if (mmap_write_lock_killable(mm)) - return -EINTR; - } else { - mmap_read_lock(mm); - } - start = untagged_addr_remote(mm, start); end = start + len; @@ -1701,10 +1740,7 @@ int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh } blk_finish_plug(&plug); - if (write) - mmap_write_unlock(mm); - else - mmap_read_unlock(mm); + madvise_unlock(mm, behavior); return error; }