Implement proper layering in the memory policy layer. Separate userspace functions and provide do_xx functions. Split up get_nodes into get_nodes( which gets the nodes) and contextualize_policy which restricts the nodes to those accessible to the task and updates cpusets. Signed-off-by: Christoph Lameter Index: linux-2.6.13-rc5/mm/mempolicy.c =================================================================== --- linux-2.6.13-rc5.orig/mm/mempolicy.c 2005-08-05 09:25:26.000000000 -0700 +++ linux-2.6.13-rc5/mm/mempolicy.c 2005-08-05 12:07:36.000000000 -0700 @@ -132,54 +132,6 @@ static int mpol_check_policy(int mode, n return nodes_online(nodes); } -/* Copy a node mask from user space. */ -static int get_nodes(nodemask_t *nodes, unsigned long __user *nmask, - unsigned long maxnode, int mode) -{ - unsigned long k; - unsigned long nlongs; - unsigned long endmask; - - --maxnode; - nodes_clear(*nodes); - if (maxnode == 0 || !nmask) - return 0; - - nlongs = BITS_TO_LONGS(maxnode); - if ((maxnode % BITS_PER_LONG) == 0) - endmask = ~0UL; - else - endmask = (1UL << (maxnode % BITS_PER_LONG)) - 1; - - /* When the user specified more nodes than supported just check - if the non supported part is all zero. */ - if (nlongs > BITS_TO_LONGS(MAX_NUMNODES)) { - if (nlongs > PAGE_SIZE/sizeof(long)) - return -EINVAL; - for (k = BITS_TO_LONGS(MAX_NUMNODES); k < nlongs; k++) { - unsigned long t; - if (get_user(t, nmask + k)) - return -EFAULT; - if (k == nlongs - 1) { - if (t & endmask) - return -EINVAL; - } else if (t) - return -EINVAL; - } - nlongs = BITS_TO_LONGS(MAX_NUMNODES); - endmask = ~0UL; - } - - if (copy_from_user(nodes, nmask, nlongs*sizeof(unsigned long))) - return -EFAULT; - nodes_addr(*nodes)[nlongs - 1] &= endmask; - /* Update current mems_allowed */ - cpuset_update_current_mems_allowed(); - /* Ignore nodes not set in current->mems_allowed */ - cpuset_restrict_to_mems_allowed(nodes); - return mpol_check_policy(mode, nodes); -} - /* Generate a custom zonelist for the BIND policy. */ static struct zonelist *bind_zonelist(nodemask_t *nodes) { @@ -393,18 +345,23 @@ static int mbind_range(struct vm_area_st return err; } -/* Change policy for a memory range */ -asmlinkage long sys_mbind(unsigned long start, unsigned long len, - unsigned long mode, - unsigned long __user *nmask, unsigned long maxnode, - unsigned flags) +static int contextualize_policy(int mode, nodemask_t *nodes) +{ + /* Update current mems_allowed */ + cpuset_update_current_mems_allowed(); + /* Ignore nodes not set in current->mems_allowed */ + cpuset_restrict_to_mems_allowed(nodes); + return mpol_check_policy(mode, nodes); +} + +long do_mbind(unsigned long start, unsigned long len, + unsigned long mode, nodemask_t *nmask, unsigned long flags) { - struct vm_area_struct *vma; struct mm_struct *mm = current->mm; + struct vm_area_struct *vma; struct mempolicy *new; - unsigned long end; - nodemask_t nodes; int err; + int end; if ((flags & ~(unsigned long)(MPOL_MF_STRICT)) || mode > MPOL_MAX) return -EINVAL; @@ -418,12 +375,9 @@ asmlinkage long sys_mbind(unsigned long return -EINVAL; if (end == start) return 0; - - err = get_nodes(&nodes, nmask, maxnode, mode); - if (err) - return err; - - new = mpol_new(mode, &nodes); + if (!contextualize_policy(mode, nmask)) + return -EINVAL; + new = mpol_new(mode, nmask); if (IS_ERR(new)) return PTR_ERR(new); @@ -431,7 +385,7 @@ asmlinkage long sys_mbind(unsigned long mode,nodes[0]); down_write(&mm->mmap_sem); - vma = check_range(mm, start, end, &nodes, flags); + vma = check_range(mm, start, end, nmask, flags); err = PTR_ERR(vma); if (!IS_ERR(vma)) err = mbind_range(vma, start, end, new); @@ -441,19 +395,13 @@ asmlinkage long sys_mbind(unsigned long } /* Set the process memory policy */ -asmlinkage long sys_set_mempolicy(int mode, unsigned long __user *nmask, - unsigned long maxnode) +long do_set_mempolicy(int mode, nodemask_t *nodes) { - int err; struct mempolicy *new; - nodemask_t nodes; - if (mode < 0 || mode > MPOL_MAX) + if (!contextualize_policy(mode, nodes)) return -EINVAL; - err = get_nodes(&nodes, nmask, maxnode, mode); - if (err) - return err; - new = mpol_new(mode, &nodes); + new = mpol_new(mode, nodes); if (IS_ERR(new)) return PTR_ERR(new); mpol_free(current->mempolicy); @@ -506,37 +454,17 @@ static int lookup_node(struct mm_struct return err; } -/* Copy a kernel node mask to user space */ -static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode, - nodemask_t *nodes, unsigned nbytes) -{ - unsigned long copy = ALIGN(maxnode - 1, 64) / 8; - - if (copy > nbytes) { - if (copy > PAGE_SIZE) - return -EINVAL; - if (clear_user((char __user *)mask + nbytes, copy - nbytes)) - return -EFAULT; - copy = nbytes; - } - return copy_to_user(mask, nodes, copy) ? -EFAULT : 0; -} - /* Retrieve NUMA policy */ -asmlinkage long sys_get_mempolicy(int __user *policy, - unsigned long __user *nmask, - unsigned long maxnode, +long do_get_mempolicy(int *policy, nodemask_t *nmask, unsigned long addr, unsigned long flags) { - int err, pval; + int err; struct mm_struct *mm = current->mm; struct vm_area_struct *vma = NULL; struct mempolicy *pol = current->mempolicy; if (flags & ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR)) return -EINVAL; - if (nmask != NULL && maxnode < MAX_NUMNODES) - return -EINVAL; if (flags & MPOL_F_ADDR) { down_read(&mm->mmap_sem); vma = find_vma_intersection(mm, addr, addr + 1); @@ -559,35 +487,151 @@ asmlinkage long sys_get_mempolicy(int __ err = lookup_node(mm, addr); if (err < 0) goto out; - pval = err; + *policy = err; } else if (pol == current->mempolicy && pol->policy == MPOL_INTERLEAVE) { - pval = current->il_next; + *policy = current->il_next; } else { err = -EINVAL; goto out; } } else - pval = pol->policy; + *policy = pol->policy; if (vma) { up_read(¤t->mm->mmap_sem); vma = NULL; } + err = 0; + if (nmask) + get_zonemask(pol, nmask); + +out: + if (vma) + up_read(¤t->mm->mmap_sem); + return err; +} + +/* + * Interface to user space + */ + +/* Copy a node mask from user space. */ +static int get_nodes(nodemask_t *nodes, unsigned long __user *nmask, + unsigned long maxnode) +{ + unsigned long k; + unsigned long nlongs; + unsigned long endmask; + + --maxnode; + nodes_clear(*nodes); + if (maxnode == 0 || !nmask) + return 0; + + nlongs = BITS_TO_LONGS(maxnode); + if ((maxnode % BITS_PER_LONG) == 0) + endmask = ~0UL; + else + endmask = (1UL << (maxnode % BITS_PER_LONG)) - 1; + + /* When the user specified more nodes than supported just check + if the non supported part is all zero. */ + if (nlongs > BITS_TO_LONGS(MAX_NUMNODES)) { + if (nlongs > PAGE_SIZE/sizeof(long)) + return -EINVAL; + for (k = BITS_TO_LONGS(MAX_NUMNODES); k < nlongs; k++) { + unsigned long t; + if (get_user(t, nmask + k)) + return -EFAULT; + if (k == nlongs - 1) { + if (t & endmask) + return -EINVAL; + } else if (t) + return -EINVAL; + } + nlongs = BITS_TO_LONGS(MAX_NUMNODES); + endmask = ~0UL; + } + + if (copy_from_user(nodes, nmask, nlongs*sizeof(unsigned long))) + return -EFAULT; + nodes_addr(*nodes)[nlongs - 1] &= endmask; + return 0; +} + +/* Change policy for a memory range */ +asmlinkage long sys_mbind(unsigned long start, unsigned long len, + unsigned long mode, + unsigned long __user *nmask, unsigned long maxnode, + unsigned flags) +{ + nodemask_t nodes; + int err; + + + err = get_nodes(&nodes, nmask, maxnode); + if (err) + return err; + + return do_mbind(start, len, mode, &nodes, flags); +} + +/* Copy a kernel node mask to user space */ +static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode, + nodemask_t *nodes, unsigned nbytes) +{ + unsigned long copy = ALIGN(maxnode - 1, 64) / 8; + + if (copy > nbytes) { + if (copy > PAGE_SIZE) + return -EINVAL; + if (clear_user((char __user *)mask + nbytes, copy - nbytes)) + return -EFAULT; + copy = nbytes; + } + return copy_to_user(mask, nodes, copy) ? -EFAULT : 0; +} + +/* Set the process memory policy */ +asmlinkage long sys_set_mempolicy(int mode, unsigned long __user *nmask, + unsigned long maxnode) +{ + int err; + nodemask_t nodes; + + if (mode < 0 || mode > MPOL_MAX) + return -EINVAL; + err = get_nodes(&nodes, nmask, maxnode); + if (err) + return err; + return do_set_mempolicy(mode, &nodes); +} + +/* Retrieve NUMA policy */ +asmlinkage long sys_get_mempolicy(int __user *policy, + unsigned long __user *nmask, + unsigned long maxnode, + unsigned long addr, unsigned long flags) +{ + int err, pval; + nodemask_t nodes; + + if (nmask != NULL && maxnode < MAX_NUMNODES) + return -EINVAL; + + err = do_get_mempolicy(&pval, &nodes, addr, flags); + + if (err) + return err; + if (policy && put_user(pval, policy)) return -EFAULT; - err = 0; - if (nmask) { - nodemask_t nodes; - get_zonemask(pol, &nodes); + if (nmask) err = copy_nodes_to_user(nmask, maxnode, &nodes, sizeof(nodes)); - } - out: - if (vma) - up_read(¤t->mm->mmap_sem); return err; }