rpc: implement new upcall
[linux-2.6] / net / sunrpc / auth_gss / auth_gss.c
1 /*
2  * linux/net/sunrpc/auth_gss/auth_gss.c
3  *
4  * RPCSEC_GSS client authentication.
5  *
6  *  Copyright (c) 2000 The Regents of the University of Michigan.
7  *  All rights reserved.
8  *
9  *  Dug Song       <dugsong@monkey.org>
10  *  Andy Adamson   <andros@umich.edu>
11  *
12  *  Redistribution and use in source and binary forms, with or without
13  *  modification, are permitted provided that the following conditions
14  *  are met:
15  *
16  *  1. Redistributions of source code must retain the above copyright
17  *     notice, this list of conditions and the following disclaimer.
18  *  2. Redistributions in binary form must reproduce the above copyright
19  *     notice, this list of conditions and the following disclaimer in the
20  *     documentation and/or other materials provided with the distribution.
21  *  3. Neither the name of the University nor the names of its
22  *     contributors may be used to endorse or promote products derived
23  *     from this software without specific prior written permission.
24  *
25  *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
26  *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27  *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28  *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
29  *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
32  *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36  */
37
38
39 #include <linux/module.h>
40 #include <linux/init.h>
41 #include <linux/types.h>
42 #include <linux/slab.h>
43 #include <linux/sched.h>
44 #include <linux/pagemap.h>
45 #include <linux/sunrpc/clnt.h>
46 #include <linux/sunrpc/auth.h>
47 #include <linux/sunrpc/auth_gss.h>
48 #include <linux/sunrpc/svcauth_gss.h>
49 #include <linux/sunrpc/gss_err.h>
50 #include <linux/workqueue.h>
51 #include <linux/sunrpc/rpc_pipe_fs.h>
52 #include <linux/sunrpc/gss_api.h>
53 #include <asm/uaccess.h>
54
55 static const struct rpc_authops authgss_ops;
56
57 static const struct rpc_credops gss_credops;
58 static const struct rpc_credops gss_nullops;
59
60 #ifdef RPC_DEBUG
61 # define RPCDBG_FACILITY        RPCDBG_AUTH
62 #endif
63
64 #define GSS_CRED_SLACK          1024
65 /* length of a krb5 verifier (48), plus data added before arguments when
66  * using integrity (two 4-byte integers): */
67 #define GSS_VERF_SLACK          100
68
69 struct gss_auth {
70         struct kref kref;
71         struct rpc_auth rpc_auth;
72         struct gss_api_mech *mech;
73         enum rpc_gss_svc service;
74         struct rpc_clnt *client;
75         /*
76          * There are two upcall pipes; dentry[1], named "gssd", is used
77          * for the new text-based upcall; dentry[0] is named after the
78          * mechanism (for example, "krb5") and exists for
79          * backwards-compatibility with older gssd's.
80          */
81         struct dentry *dentry[2];
82 };
83
84 /* pipe_version >= 0 if and only if someone has a pipe open. */
85 static int pipe_version = -1;
86 static atomic_t pipe_users = ATOMIC_INIT(0);
87 static DEFINE_SPINLOCK(pipe_version_lock);
88 static struct rpc_wait_queue pipe_version_rpc_waitqueue;
89 static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
90
91 static void gss_free_ctx(struct gss_cl_ctx *);
92 static struct rpc_pipe_ops gss_upcall_ops_v0;
93 static struct rpc_pipe_ops gss_upcall_ops_v1;
94
95 static inline struct gss_cl_ctx *
96 gss_get_ctx(struct gss_cl_ctx *ctx)
97 {
98         atomic_inc(&ctx->count);
99         return ctx;
100 }
101
102 static inline void
103 gss_put_ctx(struct gss_cl_ctx *ctx)
104 {
105         if (atomic_dec_and_test(&ctx->count))
106                 gss_free_ctx(ctx);
107 }
108
109 /* gss_cred_set_ctx:
110  * called by gss_upcall_callback and gss_create_upcall in order
111  * to set the gss context. The actual exchange of an old context
112  * and a new one is protected by the inode->i_lock.
113  */
114 static void
115 gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
116 {
117         struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
118
119         if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
120                 return;
121         gss_get_ctx(ctx);
122         rcu_assign_pointer(gss_cred->gc_ctx, ctx);
123         set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
124         smp_mb__before_clear_bit();
125         clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
126 }
127
128 static const void *
129 simple_get_bytes(const void *p, const void *end, void *res, size_t len)
130 {
131         const void *q = (const void *)((const char *)p + len);
132         if (unlikely(q > end || q < p))
133                 return ERR_PTR(-EFAULT);
134         memcpy(res, p, len);
135         return q;
136 }
137
138 static inline const void *
139 simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
140 {
141         const void *q;
142         unsigned int len;
143
144         p = simple_get_bytes(p, end, &len, sizeof(len));
145         if (IS_ERR(p))
146                 return p;
147         q = (const void *)((const char *)p + len);
148         if (unlikely(q > end || q < p))
149                 return ERR_PTR(-EFAULT);
150         dest->data = kmemdup(p, len, GFP_NOFS);
151         if (unlikely(dest->data == NULL))
152                 return ERR_PTR(-ENOMEM);
153         dest->len = len;
154         return q;
155 }
156
157 static struct gss_cl_ctx *
158 gss_cred_get_ctx(struct rpc_cred *cred)
159 {
160         struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
161         struct gss_cl_ctx *ctx = NULL;
162
163         rcu_read_lock();
164         if (gss_cred->gc_ctx)
165                 ctx = gss_get_ctx(gss_cred->gc_ctx);
166         rcu_read_unlock();
167         return ctx;
168 }
169
170 static struct gss_cl_ctx *
171 gss_alloc_context(void)
172 {
173         struct gss_cl_ctx *ctx;
174
175         ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
176         if (ctx != NULL) {
177                 ctx->gc_proc = RPC_GSS_PROC_DATA;
178                 ctx->gc_seq = 1;        /* NetApp 6.4R1 doesn't accept seq. no. 0 */
179                 spin_lock_init(&ctx->gc_seq_lock);
180                 atomic_set(&ctx->count,1);
181         }
182         return ctx;
183 }
184
185 #define GSSD_MIN_TIMEOUT (60 * 60)
186 static const void *
187 gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
188 {
189         const void *q;
190         unsigned int seclen;
191         unsigned int timeout;
192         u32 window_size;
193         int ret;
194
195         /* First unsigned int gives the lifetime (in seconds) of the cred */
196         p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
197         if (IS_ERR(p))
198                 goto err;
199         if (timeout == 0)
200                 timeout = GSSD_MIN_TIMEOUT;
201         ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
202         /* Sequence number window. Determines the maximum number of simultaneous requests */
203         p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
204         if (IS_ERR(p))
205                 goto err;
206         ctx->gc_win = window_size;
207         /* gssd signals an error by passing ctx->gc_win = 0: */
208         if (ctx->gc_win == 0) {
209                 /* in which case, p points to  an error code which we ignore */
210                 p = ERR_PTR(-EACCES);
211                 goto err;
212         }
213         /* copy the opaque wire context */
214         p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
215         if (IS_ERR(p))
216                 goto err;
217         /* import the opaque security context */
218         p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
219         if (IS_ERR(p))
220                 goto err;
221         q = (const void *)((const char *)p + seclen);
222         if (unlikely(q > end || q < p)) {
223                 p = ERR_PTR(-EFAULT);
224                 goto err;
225         }
226         ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx);
227         if (ret < 0) {
228                 p = ERR_PTR(ret);
229                 goto err;
230         }
231         return q;
232 err:
233         dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
234         return p;
235 }
236
237 #define UPCALL_BUF_LEN 128
238
239 struct gss_upcall_msg {
240         atomic_t count;
241         uid_t   uid;
242         struct rpc_pipe_msg msg;
243         struct list_head list;
244         struct gss_auth *auth;
245         struct rpc_inode *inode;
246         struct rpc_wait_queue rpc_waitqueue;
247         wait_queue_head_t waitqueue;
248         struct gss_cl_ctx *ctx;
249         char databuf[UPCALL_BUF_LEN];
250 };
251
252 static int get_pipe_version(void)
253 {
254         int ret;
255
256         spin_lock(&pipe_version_lock);
257         if (pipe_version >= 0) {
258                 atomic_inc(&pipe_users);
259                 ret = pipe_version;
260         } else
261                 ret = -EAGAIN;
262         spin_unlock(&pipe_version_lock);
263         return ret;
264 }
265
266 static void put_pipe_version(void)
267 {
268         if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
269                 pipe_version = -1;
270                 spin_unlock(&pipe_version_lock);
271         }
272 }
273
274 static void
275 gss_release_msg(struct gss_upcall_msg *gss_msg)
276 {
277         if (!atomic_dec_and_test(&gss_msg->count))
278                 return;
279         put_pipe_version();
280         BUG_ON(!list_empty(&gss_msg->list));
281         if (gss_msg->ctx != NULL)
282                 gss_put_ctx(gss_msg->ctx);
283         rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
284         kfree(gss_msg);
285 }
286
287 static struct gss_upcall_msg *
288 __gss_find_upcall(struct rpc_inode *rpci, uid_t uid)
289 {
290         struct gss_upcall_msg *pos;
291         list_for_each_entry(pos, &rpci->in_downcall, list) {
292                 if (pos->uid != uid)
293                         continue;
294                 atomic_inc(&pos->count);
295                 dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
296                 return pos;
297         }
298         dprintk("RPC:       gss_find_upcall found nothing\n");
299         return NULL;
300 }
301
302 /* Try to add an upcall to the pipefs queue.
303  * If an upcall owned by our uid already exists, then we return a reference
304  * to that upcall instead of adding the new upcall.
305  */
306 static inline struct gss_upcall_msg *
307 gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg)
308 {
309         struct rpc_inode *rpci = gss_msg->inode;
310         struct inode *inode = &rpci->vfs_inode;
311         struct gss_upcall_msg *old;
312
313         spin_lock(&inode->i_lock);
314         old = __gss_find_upcall(rpci, gss_msg->uid);
315         if (old == NULL) {
316                 atomic_inc(&gss_msg->count);
317                 list_add(&gss_msg->list, &rpci->in_downcall);
318         } else
319                 gss_msg = old;
320         spin_unlock(&inode->i_lock);
321         return gss_msg;
322 }
323
324 static void
325 __gss_unhash_msg(struct gss_upcall_msg *gss_msg)
326 {
327         list_del_init(&gss_msg->list);
328         rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
329         wake_up_all(&gss_msg->waitqueue);
330         atomic_dec(&gss_msg->count);
331 }
332
333 static void
334 gss_unhash_msg(struct gss_upcall_msg *gss_msg)
335 {
336         struct inode *inode = &gss_msg->inode->vfs_inode;
337
338         if (list_empty(&gss_msg->list))
339                 return;
340         spin_lock(&inode->i_lock);
341         if (!list_empty(&gss_msg->list))
342                 __gss_unhash_msg(gss_msg);
343         spin_unlock(&inode->i_lock);
344 }
345
346 static void
347 gss_upcall_callback(struct rpc_task *task)
348 {
349         struct gss_cred *gss_cred = container_of(task->tk_msg.rpc_cred,
350                         struct gss_cred, gc_base);
351         struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
352         struct inode *inode = &gss_msg->inode->vfs_inode;
353
354         spin_lock(&inode->i_lock);
355         if (gss_msg->ctx)
356                 gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx);
357         else
358                 task->tk_status = gss_msg->msg.errno;
359         gss_cred->gc_upcall = NULL;
360         rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
361         spin_unlock(&inode->i_lock);
362         gss_release_msg(gss_msg);
363 }
364
365 static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
366 {
367         gss_msg->msg.data = &gss_msg->uid;
368         gss_msg->msg.len = sizeof(gss_msg->uid);
369 }
370
371 static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg)
372 {
373         gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d\n",
374                                    gss_msg->auth->mech->gm_name,
375                                    gss_msg->uid);
376         gss_msg->msg.data = gss_msg->databuf;
377         BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
378 }
379
380 static void gss_encode_msg(struct gss_upcall_msg *gss_msg)
381 {
382         if (pipe_version == 0)
383                 gss_encode_v0_msg(gss_msg);
384         else /* pipe_version == 1 */
385                 gss_encode_v1_msg(gss_msg);
386 }
387
388 static inline struct gss_upcall_msg *
389 gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
390 {
391         struct gss_upcall_msg *gss_msg;
392         int vers;
393
394         gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
395         if (gss_msg == NULL)
396                 return ERR_PTR(-ENOMEM);
397         vers = get_pipe_version();
398         if (vers < 0) {
399                 kfree(gss_msg);
400                 return ERR_PTR(vers);
401         }
402         gss_msg->inode = RPC_I(gss_auth->dentry[vers]->d_inode);
403         INIT_LIST_HEAD(&gss_msg->list);
404         rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
405         init_waitqueue_head(&gss_msg->waitqueue);
406         atomic_set(&gss_msg->count, 1);
407         gss_msg->uid = uid;
408         gss_msg->auth = gss_auth;
409         gss_encode_msg(gss_msg);
410         return gss_msg;
411 }
412
413 static struct gss_upcall_msg *
414 gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
415 {
416         struct gss_cred *gss_cred = container_of(cred,
417                         struct gss_cred, gc_base);
418         struct gss_upcall_msg *gss_new, *gss_msg;
419         uid_t uid = cred->cr_uid;
420
421         /* Special case: rpc.gssd assumes that uid == 0 implies machine creds */
422         if (gss_cred->gc_machine_cred != 0)
423                 uid = 0;
424
425         gss_new = gss_alloc_msg(gss_auth, uid);
426         if (IS_ERR(gss_new))
427                 return gss_new;
428         gss_msg = gss_add_msg(gss_auth, gss_new);
429         if (gss_msg == gss_new) {
430                 struct inode *inode = &gss_new->inode->vfs_inode;
431                 int res = rpc_queue_upcall(inode, &gss_new->msg);
432                 if (res) {
433                         gss_unhash_msg(gss_new);
434                         gss_msg = ERR_PTR(res);
435                 }
436         } else
437                 gss_release_msg(gss_new);
438         return gss_msg;
439 }
440
441 static void warn_gssd(void)
442 {
443         static unsigned long ratelimit;
444         unsigned long now = jiffies;
445
446         if (time_after(now, ratelimit)) {
447                 printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
448                                 "Please check user daemon is running.\n");
449                 ratelimit = now + 15*HZ;
450         }
451 }
452
453 static inline int
454 gss_refresh_upcall(struct rpc_task *task)
455 {
456         struct rpc_cred *cred = task->tk_msg.rpc_cred;
457         struct gss_auth *gss_auth = container_of(cred->cr_auth,
458                         struct gss_auth, rpc_auth);
459         struct gss_cred *gss_cred = container_of(cred,
460                         struct gss_cred, gc_base);
461         struct gss_upcall_msg *gss_msg;
462         struct inode *inode;
463         int err = 0;
464
465         dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
466                                                                 cred->cr_uid);
467         gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
468         if (IS_ERR(gss_msg) == -EAGAIN) {
469                 /* XXX: warning on the first, under the assumption we
470                  * shouldn't normally hit this case on a refresh. */
471                 warn_gssd();
472                 task->tk_timeout = 15*HZ;
473                 rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
474                 return 0;
475         }
476         if (IS_ERR(gss_msg)) {
477                 err = PTR_ERR(gss_msg);
478                 goto out;
479         }
480         inode = &gss_msg->inode->vfs_inode;
481         spin_lock(&inode->i_lock);
482         if (gss_cred->gc_upcall != NULL)
483                 rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
484         else if (gss_msg->ctx != NULL) {
485                 gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx);
486                 gss_cred->gc_upcall = NULL;
487                 rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
488         } else if (gss_msg->msg.errno >= 0) {
489                 task->tk_timeout = 0;
490                 gss_cred->gc_upcall = gss_msg;
491                 /* gss_upcall_callback will release the reference to gss_upcall_msg */
492                 atomic_inc(&gss_msg->count);
493                 rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
494         } else
495                 err = gss_msg->msg.errno;
496         spin_unlock(&inode->i_lock);
497         gss_release_msg(gss_msg);
498 out:
499         dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
500                         task->tk_pid, cred->cr_uid, err);
501         return err;
502 }
503
504 static inline int
505 gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
506 {
507         struct inode *inode;
508         struct rpc_cred *cred = &gss_cred->gc_base;
509         struct gss_upcall_msg *gss_msg;
510         DEFINE_WAIT(wait);
511         int err = 0;
512
513         dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
514 retry:
515         gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
516         if (PTR_ERR(gss_msg) == -EAGAIN) {
517                 err = wait_event_interruptible_timeout(pipe_version_waitqueue,
518                                 pipe_version >= 0, 15*HZ);
519                 if (err)
520                         goto out;
521                 if (pipe_version < 0)
522                         warn_gssd();
523                 goto retry;
524         }
525         if (IS_ERR(gss_msg)) {
526                 err = PTR_ERR(gss_msg);
527                 goto out;
528         }
529         inode = &gss_msg->inode->vfs_inode;
530         for (;;) {
531                 prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
532                 spin_lock(&inode->i_lock);
533                 if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
534                         break;
535                 }
536                 spin_unlock(&inode->i_lock);
537                 if (signalled()) {
538                         err = -ERESTARTSYS;
539                         goto out_intr;
540                 }
541                 schedule();
542         }
543         if (gss_msg->ctx)
544                 gss_cred_set_ctx(cred, gss_msg->ctx);
545         else
546                 err = gss_msg->msg.errno;
547         spin_unlock(&inode->i_lock);
548 out_intr:
549         finish_wait(&gss_msg->waitqueue, &wait);
550         gss_release_msg(gss_msg);
551 out:
552         dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
553                         cred->cr_uid, err);
554         return err;
555 }
556
557 static ssize_t
558 gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
559                 char __user *dst, size_t buflen)
560 {
561         char *data = (char *)msg->data + msg->copied;
562         size_t mlen = min(msg->len, buflen);
563         unsigned long left;
564
565         left = copy_to_user(dst, data, mlen);
566         if (left == mlen) {
567                 msg->errno = -EFAULT;
568                 return -EFAULT;
569         }
570
571         mlen -= left;
572         msg->copied += mlen;
573         msg->errno = 0;
574         return mlen;
575 }
576
577 #define MSG_BUF_MAXSIZE 1024
578
579 static ssize_t
580 gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
581 {
582         const void *p, *end;
583         void *buf;
584         struct gss_upcall_msg *gss_msg;
585         struct inode *inode = filp->f_path.dentry->d_inode;
586         struct gss_cl_ctx *ctx;
587         uid_t uid;
588         ssize_t err = -EFBIG;
589
590         if (mlen > MSG_BUF_MAXSIZE)
591                 goto out;
592         err = -ENOMEM;
593         buf = kmalloc(mlen, GFP_NOFS);
594         if (!buf)
595                 goto out;
596
597         err = -EFAULT;
598         if (copy_from_user(buf, src, mlen))
599                 goto err;
600
601         end = (const void *)((char *)buf + mlen);
602         p = simple_get_bytes(buf, end, &uid, sizeof(uid));
603         if (IS_ERR(p)) {
604                 err = PTR_ERR(p);
605                 goto err;
606         }
607
608         err = -ENOMEM;
609         ctx = gss_alloc_context();
610         if (ctx == NULL)
611                 goto err;
612
613         err = -ENOENT;
614         /* Find a matching upcall */
615         spin_lock(&inode->i_lock);
616         gss_msg = __gss_find_upcall(RPC_I(inode), uid);
617         if (gss_msg == NULL) {
618                 spin_unlock(&inode->i_lock);
619                 goto err_put_ctx;
620         }
621         list_del_init(&gss_msg->list);
622         spin_unlock(&inode->i_lock);
623
624         p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
625         if (IS_ERR(p)) {
626                 err = PTR_ERR(p);
627                 gss_msg->msg.errno = (err == -EAGAIN) ? -EAGAIN : -EACCES;
628                 goto err_release_msg;
629         }
630         gss_msg->ctx = gss_get_ctx(ctx);
631         err = mlen;
632
633 err_release_msg:
634         spin_lock(&inode->i_lock);
635         __gss_unhash_msg(gss_msg);
636         spin_unlock(&inode->i_lock);
637         gss_release_msg(gss_msg);
638 err_put_ctx:
639         gss_put_ctx(ctx);
640 err:
641         kfree(buf);
642 out:
643         dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
644         return err;
645 }
646
647 static int gss_pipe_open(struct inode *inode, int new_version)
648 {
649         int ret = 0;
650
651         spin_lock(&pipe_version_lock);
652         if (pipe_version < 0) {
653                 /* First open of any gss pipe determines the version: */
654                 pipe_version = new_version;
655                 rpc_wake_up(&pipe_version_rpc_waitqueue);
656                 wake_up(&pipe_version_waitqueue);
657         } else if (pipe_version != new_version) {
658                 /* Trying to open a pipe of a different version */
659                 ret = -EBUSY;
660                 goto out;
661         }
662         atomic_inc(&pipe_users);
663 out:
664         spin_unlock(&pipe_version_lock);
665         return ret;
666
667 }
668
/* open() handler for the legacy, mechanism-named pipe (upcall version 0). */
static int gss_pipe_open_v0(struct inode *inode)
{
	const int version = 0;

	return gss_pipe_open(inode, version);
}
673
/* open() handler for the "gssd" text pipe (upcall version 1). */
static int gss_pipe_open_v1(struct inode *inode)
{
	const int version = 1;

	return gss_pipe_open(inode, version);
}
678
679 static void
680 gss_pipe_release(struct inode *inode)
681 {
682         struct rpc_inode *rpci = RPC_I(inode);
683         struct gss_upcall_msg *gss_msg;
684
685         spin_lock(&inode->i_lock);
686         while (!list_empty(&rpci->in_downcall)) {
687
688                 gss_msg = list_entry(rpci->in_downcall.next,
689                                 struct gss_upcall_msg, list);
690                 gss_msg->msg.errno = -EPIPE;
691                 atomic_inc(&gss_msg->count);
692                 __gss_unhash_msg(gss_msg);
693                 spin_unlock(&inode->i_lock);
694                 gss_release_msg(gss_msg);
695                 spin_lock(&inode->i_lock);
696         }
697         spin_unlock(&inode->i_lock);
698
699         put_pipe_version();
700 }
701
702 static void
703 gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
704 {
705         struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
706
707         if (msg->errno < 0) {
708                 dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
709                                 gss_msg);
710                 atomic_inc(&gss_msg->count);
711                 gss_unhash_msg(gss_msg);
712                 if (msg->errno == -ETIMEDOUT)
713                         warn_gssd();
714                 gss_release_msg(gss_msg);
715         }
716 }
717
718 /*
719  * NOTE: we have the opportunity to use different
720  * parameters based on the input flavor (which must be a pseudoflavor)
721  */
722 static struct rpc_auth *
723 gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
724 {
725         struct gss_auth *gss_auth;
726         struct rpc_auth * auth;
727         int err = -ENOMEM; /* XXX? */
728
729         dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);
730
731         if (!try_module_get(THIS_MODULE))
732                 return ERR_PTR(err);
733         if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
734                 goto out_dec;
735         gss_auth->client = clnt;
736         err = -EINVAL;
737         gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
738         if (!gss_auth->mech) {
739                 printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
740                                 __func__, flavor);
741                 goto err_free;
742         }
743         gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
744         if (gss_auth->service == 0)
745                 goto err_put_mech;
746         auth = &gss_auth->rpc_auth;
747         auth->au_cslack = GSS_CRED_SLACK >> 2;
748         auth->au_rslack = GSS_VERF_SLACK >> 2;
749         auth->au_ops = &authgss_ops;
750         auth->au_flavor = flavor;
751         atomic_set(&auth->au_count, 1);
752         kref_init(&gss_auth->kref);
753
754         /*
755          * Note: if we created the old pipe first, then someone who
756          * examined the directory at the right moment might conclude
757          * that we supported only the old pipe.  So we instead create
758          * the new pipe first.
759          */
760         gss_auth->dentry[1] = rpc_mkpipe(clnt->cl_dentry,
761                                          "gssd",
762                                          clnt, &gss_upcall_ops_v1,
763                                          RPC_PIPE_WAIT_FOR_OPEN);
764         if (IS_ERR(gss_auth->dentry[1])) {
765                 err = PTR_ERR(gss_auth->dentry[1]);
766                 goto err_put_mech;
767         }
768
769         gss_auth->dentry[0] = rpc_mkpipe(clnt->cl_dentry,
770                                          gss_auth->mech->gm_name,
771                                          clnt, &gss_upcall_ops_v0,
772                                          RPC_PIPE_WAIT_FOR_OPEN);
773         if (IS_ERR(gss_auth->dentry[0])) {
774                 err = PTR_ERR(gss_auth->dentry[0]);
775                 goto err_unlink_pipe_1;
776         }
777         err = rpcauth_init_credcache(auth);
778         if (err)
779                 goto err_unlink_pipe_0;
780
781         return auth;
782 err_unlink_pipe_0:
783         rpc_unlink(gss_auth->dentry[0]);
784 err_unlink_pipe_1:
785         rpc_unlink(gss_auth->dentry[1]);
786 err_put_mech:
787         gss_mech_put(gss_auth->mech);
788 err_free:
789         kfree(gss_auth);
790 out_dec:
791         module_put(THIS_MODULE);
792         return ERR_PTR(err);
793 }
794
795 static void
796 gss_free(struct gss_auth *gss_auth)
797 {
798         rpc_unlink(gss_auth->dentry[1]);
799         rpc_unlink(gss_auth->dentry[0]);
800         gss_mech_put(gss_auth->mech);
801
802         kfree(gss_auth);
803         module_put(THIS_MODULE);
804 }
805
806 static void
807 gss_free_callback(struct kref *kref)
808 {
809         struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
810
811         gss_free(gss_auth);
812 }
813
814 static void
815 gss_destroy(struct rpc_auth *auth)
816 {
817         struct gss_auth *gss_auth;
818
819         dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
820                         auth, auth->au_flavor);
821
822         rpcauth_destroy_credcache(auth);
823
824         gss_auth = container_of(auth, struct gss_auth, rpc_auth);
825         kref_put(&gss_auth->kref, gss_free_callback);
826 }
827
828 /*
829  * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
830  * to the server with the GSS control procedure field set to
831  * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
832  * all RPCSEC_GSS state associated with that context.
833  */
834 static int
835 gss_destroying_context(struct rpc_cred *cred)
836 {
837         struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
838         struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
839         struct rpc_task *task;
840
841         if (gss_cred->gc_ctx == NULL ||
842             test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
843                 return 0;
844
845         gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
846         cred->cr_ops = &gss_nullops;
847
848         /* Take a reference to ensure the cred will be destroyed either
849          * by the RPC call or by the put_rpccred() below */
850         get_rpccred(cred);
851
852         task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
853         if (!IS_ERR(task))
854                 rpc_put_task(task);
855
856         put_rpccred(cred);
857         return 1;
858 }
859
860 /* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
861  * to create a new cred or context, so they check that things have been
862  * allocated before freeing them. */
863 static void
864 gss_do_free_ctx(struct gss_cl_ctx *ctx)
865 {
866         dprintk("RPC:       gss_free_ctx\n");
867
868         kfree(ctx->gc_wire_ctx.data);
869         kfree(ctx);
870 }
871
872 static void
873 gss_free_ctx_callback(struct rcu_head *head)
874 {
875         struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
876         gss_do_free_ctx(ctx);
877 }
878
/*
 * Tear down a client context. The mechanism-level security context is
 * detached from @ctx, and freeing @ctx itself is deferred to an RCU
 * callback (presumably so RCU readers still holding @ctx can finish).
 * The detached mechanism context is then deleted.
 */
static void
gss_free_ctx(struct gss_cl_ctx *ctx)
{
	struct gss_ctx *gc_gss_ctx;

	/* Unpublish the mechanism context before queuing the RCU free. */
	gc_gss_ctx = rcu_dereference(ctx->gc_gss_ctx);
	rcu_assign_pointer(ctx->gc_gss_ctx, NULL);
	call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
	/* May be NULL when cleaning up a partially constructed context. */
	if (gc_gss_ctx)
		gss_delete_sec_context(&gc_gss_ctx);
}
890
/* Release the memory of a gss_cred; invoked from gss_free_cred_callback(). */
static void
gss_free_cred(struct gss_cred *victim)
{
	dprintk("RPC:       gss_free_cred %p\n", victim);
	kfree(victim);
}
897
898 static void
899 gss_free_cred_callback(struct rcu_head *head)
900 {
901         struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
902         gss_free_cred(gss_cred);
903 }
904
/*
 * Final teardown of a GSS cred. The steps are:
 *   - detach its context;
 *   - queue the cred to be freed after an RCU grace period;
 *   - drop the cred's reference on the context;
 *   - drop the gss_auth reference taken in gss_create_cred().
 * This function is also the crdestroy hook of gss_nullops.
 */
static void
gss_destroy_nullcred(struct rpc_cred *cred)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
	struct gss_cl_ctx *ctx = gss_cred->gc_ctx;

	/* Unpublish the context before the cred is queued for freeing. */
	rcu_assign_pointer(gss_cred->gc_ctx, NULL);
	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
	if (ctx)
		gss_put_ctx(ctx);
	/* Balances the kref_get() in gss_create_cred(). */
	kref_put(&gss_auth->kref, gss_free_callback);
}
918
/*
 * crdestroy hook for live creds. It first tries to have the server drop
 * the context. If that was not attempted (no established context), the
 * cred is torn down immediately.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
	if (!gss_destroying_context(cred))
		gss_destroy_nullcred(cred);
}
927
/*
 * Look up the RPCSEC_GSS cred for the current process.
 * This delegates to the generic credential cache of @auth; @flags is
 * passed through unchanged (e.g. RPCAUTH_LOOKUP_NEW from gss_renew_cred).
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
	struct rpc_cred *found = rpcauth_lookup_credcache(auth, acred, flags);

	return found;
}
936
937 static struct rpc_cred *
938 gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
939 {
940         struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
941         struct gss_cred *cred = NULL;
942         int err = -ENOMEM;
943
944         dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
945                 acred->uid, auth->au_flavor);
946
947         if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
948                 goto out_err;
949
950         rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
951         /*
952          * Note: in order to force a call to call_refresh(), we deliberately
953          * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
954          */
955         cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
956         cred->gc_service = gss_auth->service;
957         cred->gc_machine_cred = acred->machine_cred;
958         kref_get(&gss_auth->kref);
959         return &cred->gc_base;
960
961 out_err:
962         dprintk("RPC:       gss_create_cred failed with error %d\n", err);
963         return ERR_PTR(err);
964 }
965
966 static int
967 gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
968 {
969         struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
970         struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
971         int err;
972
973         do {
974                 err = gss_create_upcall(gss_auth, gss_cred);
975         } while (err == -EAGAIN);
976         return err;
977 }
978
979 static int
980 gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
981 {
982         struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
983
984         if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
985                 goto out;
986         /* Don't match with creds that have expired. */
987         if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
988                 return 0;
989         if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
990                 return 0;
991 out:
992         if (acred->machine_cred != gss_cred->gc_machine_cred)
993                 return 0;
994         return (rc->cr_uid == acred->uid);
995 }
996
/*
 * Marshal credentials.
 *
 * This writes the RPCSEC_GSS credential (version, gc_proc, sequence
 * number, service, wire context handle), followed by the verifier. The
 * verifier is a MIC over the RPC header, from the xid through the end of
 * the credential. It also assigns the request its sequence number, taken
 * from ctx->gc_seq.
 *
 * Returns the position just past the verifier. Returns NULL if
 * gss_get_mic() fails with anything other than GSS_S_CONTEXT_EXPIRED.
 *
 * Maybe we should keep a cached credential for performance reasons.
 */
static __be32 *
gss_marshal(struct rpc_task *task, __be32 *p)
{
	struct rpc_cred *cred = task->tk_msg.rpc_cred;
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
						 gc_base);
	struct gss_cl_ctx	*ctx = gss_cred_get_ctx(cred);
	__be32		*cred_len;
	struct rpc_rqst *req = task->tk_rqstp;
	u32		maj_stat = 0;
	struct xdr_netobj mic;
	struct kvec	iov;
	struct xdr_buf	verf_buf;

	dprintk("RPC: %5u gss_marshal\n", task->tk_pid);

	*p++ = htonl(RPC_AUTH_GSS);
	/* The credential length word is back-filled once the body is encoded. */
	cred_len = p++;

	/* Take this request's seqno and advance the per-context counter. */
	spin_lock(&ctx->gc_seq_lock);
	req->rq_seqno = ctx->gc_seq++;
	spin_unlock(&ctx->gc_seq_lock);

	*p++ = htonl((u32) RPC_GSS_VERSION);
	*p++ = htonl((u32) ctx->gc_proc);
	*p++ = htonl((u32) req->rq_seqno);
	*p++ = htonl((u32) gss_cred->gc_service);
	p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
	/* Body length in bytes: words written after the length word, times 4. */
	*cred_len = htonl((p - (cred_len + 1)) << 2);

	/* We compute the checksum for the verifier over the xdr-encoded bytes
	 * starting with the xid and ending at the end of the credential: */
	iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
					req->rq_snd_buf.head[0].iov_base);
	iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
	xdr_buf_from_iov(&iov, &verf_buf);

	/* set verifier flavor*/
	*p++ = htonl(RPC_AUTH_GSS);

	/* The MIC is generated in place, one word past p, which leaves room
	 * for its length word. xdr_encode_opaque() is given a NULL source so
	 * that (presumably) it only writes the length and padding around
	 * the MIC that is already there. */
	mic.data = (u8 *)(p + 1);
	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
	if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
		/* Still send this request. Clearing UPTODATE makes a later
		 * gss_refresh() replace the cred via gss_renew_cred(). */
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	} else if (maj_stat != 0) {
		printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
		goto out_put_ctx;
	}
	p = xdr_encode_opaque(p, NULL, mic.len);
	gss_put_ctx(ctx);
	return p;
out_put_ctx:
	gss_put_ctx(ctx);
	return NULL;
}
1056
1057 static int gss_renew_cred(struct rpc_task *task)
1058 {
1059         struct rpc_cred *oldcred = task->tk_msg.rpc_cred;
1060         struct gss_cred *gss_cred = container_of(oldcred,
1061                                                  struct gss_cred,
1062                                                  gc_base);
1063         struct rpc_auth *auth = oldcred->cr_auth;
1064         struct auth_cred acred = {
1065                 .uid = oldcred->cr_uid,
1066                 .machine_cred = gss_cred->gc_machine_cred,
1067         };
1068         struct rpc_cred *new;
1069
1070         new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1071         if (IS_ERR(new))
1072                 return PTR_ERR(new);
1073         task->tk_msg.rpc_cred = new;
1074         put_rpccred(oldcred);
1075         return 0;
1076 }
1077
1078 /*
1079 * Refresh credentials. XXX - finish
1080 */
1081 static int
1082 gss_refresh(struct rpc_task *task)
1083 {
1084         struct rpc_cred *cred = task->tk_msg.rpc_cred;
1085         int ret = 0;
1086
1087         if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1088                         !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1089                 ret = gss_renew_cred(task);
1090                 if (ret < 0)
1091                         goto out;
1092                 cred = task->tk_msg.rpc_cred;
1093         }
1094
1095         if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1096                 ret = gss_refresh_upcall(task);
1097 out:
1098         return ret;
1099 }
1100
1101 /* Dummy refresh routine: used only when destroying the context */
1102 static int
1103 gss_refresh_null(struct rpc_task *task)
1104 {
1105         return -EACCES;
1106 }
1107
1108 static __be32 *
1109 gss_validate(struct rpc_task *task, __be32 *p)
1110 {
1111         struct rpc_cred *cred = task->tk_msg.rpc_cred;
1112         struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1113         __be32          seq;
1114         struct kvec     iov;
1115         struct xdr_buf  verf_buf;
1116         struct xdr_netobj mic;
1117         u32             flav,len;
1118         u32             maj_stat;
1119
1120         dprintk("RPC: %5u gss_validate\n", task->tk_pid);
1121
1122         flav = ntohl(*p++);
1123         if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1124                 goto out_bad;
1125         if (flav != RPC_AUTH_GSS)
1126                 goto out_bad;
1127         seq = htonl(task->tk_rqstp->rq_seqno);
1128         iov.iov_base = &seq;
1129         iov.iov_len = sizeof(seq);
1130         xdr_buf_from_iov(&iov, &verf_buf);
1131         mic.data = (u8 *)p;
1132         mic.len = len;
1133
1134         maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1135         if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1136                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1137         if (maj_stat) {
1138                 dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
1139                                 "error 0x%08x\n", task->tk_pid, maj_stat);
1140                 goto out_bad;
1141         }
1142         /* We leave it to unwrap to calculate au_rslack. For now we just
1143          * calculate the length of the verifier: */
1144         cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1145         gss_put_ctx(ctx);
1146         dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
1147                         task->tk_pid);
1148         return p + XDR_QUADLEN(len);
1149 out_bad:
1150         gss_put_ctx(ctx);
1151         dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
1152         return NULL;
1153 }
1154
1155 static inline int
1156 gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1157                 kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj)
1158 {
1159         struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
1160         struct xdr_buf  integ_buf;
1161         __be32          *integ_len = NULL;
1162         struct xdr_netobj mic;
1163         u32             offset;
1164         __be32          *q;
1165         struct kvec     *iov;
1166         u32             maj_stat = 0;
1167         int             status = -EIO;
1168
1169         integ_len = p++;
1170         offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1171         *p++ = htonl(rqstp->rq_seqno);
1172
1173         status = encode(rqstp, p, obj);
1174         if (status)
1175                 return status;
1176
1177         if (xdr_buf_subsegment(snd_buf, &integ_buf,
1178                                 offset, snd_buf->len - offset))
1179                 return status;
1180         *integ_len = htonl(integ_buf.len);
1181
1182         /* guess whether we're in the head or the tail: */
1183         if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1184                 iov = snd_buf->tail;
1185         else
1186                 iov = snd_buf->head;
1187         p = iov->iov_base + iov->iov_len;
1188         mic.data = (u8 *)(p + 1);
1189
1190         maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1191         status = -EIO; /* XXX? */
1192         if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1193                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1194         else if (maj_stat)
1195                 return status;
1196         q = xdr_encode_opaque(p, NULL, mic.len);
1197
1198         offset = (u8 *)q - (u8 *)p;
1199         iov->iov_len += offset;
1200         snd_buf->len += offset;
1201         return 0;
1202 }
1203
1204 static void
1205 priv_release_snd_buf(struct rpc_rqst *rqstp)
1206 {
1207         int i;
1208
1209         for (i=0; i < rqstp->rq_enc_pages_num; i++)
1210                 __free_page(rqstp->rq_enc_pages[i]);
1211         kfree(rqstp->rq_enc_pages);
1212 }
1213
1214 static int
1215 alloc_enc_pages(struct rpc_rqst *rqstp)
1216 {
1217         struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1218         int first, last, i;
1219
1220         if (snd_buf->page_len == 0) {
1221                 rqstp->rq_enc_pages_num = 0;
1222                 return 0;
1223         }
1224
1225         first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1226         last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1227         rqstp->rq_enc_pages_num = last - first + 1 + 1;
1228         rqstp->rq_enc_pages
1229                 = kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1230                                 GFP_NOFS);
1231         if (!rqstp->rq_enc_pages)
1232                 goto out;
1233         for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1234                 rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1235                 if (rqstp->rq_enc_pages[i] == NULL)
1236                         goto out_free;
1237         }
1238         rqstp->rq_release_snd_buf = priv_release_snd_buf;
1239         return 0;
1240 out_free:
1241         for (i--; i >= 0; i--) {
1242                 __free_page(rqstp->rq_enc_pages[i]);
1243         }
1244 out:
1245         return -EAGAIN;
1246 }
1247
/*
 * Encode the procedure arguments as rpc_gss_priv_data:
 *   opaque length | gss_wrap(seqno + args) | XDR padding
 *
 * The arguments are first encoded in the clear. gss_wrap() then wraps
 * everything from the seqno onwards. It is given the caller's original
 * pages (@inpages) as input, while snd_buf->pages is switched to bounce
 * pages from alloc_enc_pages(). Presumably the wrapped page data lands
 * in the bounce pages, leaving the caller's pages unmodified; confirm
 * against gss_wrap().
 *
 * Returns 0, the encoder's error, the error from alloc_enc_pages(), or
 * -EIO if wrapping failed.
 */
static inline int
gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
		kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj)
{
	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
	u32		offset;
	u32		maj_stat;
	int		status;
	__be32		*opaque_len;
	struct page	**inpages;
	int		first;
	int		pad;
	struct kvec	*iov;
	char		*tmp;

	/* The length word is back-filled after wrapping. */
	opaque_len = p++;
	/* Start of the data to wrap (the seqno), relative to the head. */
	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
	*p++ = htonl(rqstp->rq_seqno);

	status = encode(rqstp, p, obj);
	if (status)
		return status;

	status = alloc_enc_pages(rqstp);
	if (status)
		return status;
	/* Keep the caller's pages as gss_wrap() input. Point snd_buf at the
	 * bounce pages, and rebase page_base to the first bounce page. */
	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
	inpages = snd_buf->pages + first;
	snd_buf->pages = rqstp->rq_enc_pages;
	snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
	/* Give the tail its own page, in case we need extra space in the
	 * head when wrapping: */
	/* NOTE(review): alloc_enc_pages() allocates nothing when page_len
	 * is 0. This branch therefore relies on the tail being empty
	 * whenever there is no page data; otherwise the index below would
	 * be out of range. Confirm. */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
		tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
		memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
		snd_buf->tail[0].iov_base = tmp;
	}
	maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
	/* RPC_SLACK_SPACE should prevent this ever happening: */
	BUG_ON(snd_buf->len > snd_buf->buflen);
	status = -EIO;
	/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
	 * done anyway, so it's safe to put the request on the wire: */
	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	else if (maj_stat)
		return status;

	*opaque_len = htonl(snd_buf->len - offset);
	/* guess whether we're in the head or the tail: */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
		iov = snd_buf->tail;
	else
		iov = snd_buf->head;
	p = iov->iov_base + iov->iov_len;
	/* XDR-pad the wrapped data to a multiple of 4 bytes. */
	pad = 3 - ((snd_buf->len - offset - 1) & 3);
	memset(p, 0, pad);
	iov->iov_len += pad;
	snd_buf->len += pad;

	return 0;
}
1310
1311 static int
1312 gss_wrap_req(struct rpc_task *task,
1313              kxdrproc_t encode, void *rqstp, __be32 *p, void *obj)
1314 {
1315         struct rpc_cred *cred = task->tk_msg.rpc_cred;
1316         struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1317                         gc_base);
1318         struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1319         int             status = -EIO;
1320
1321         dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1322         if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1323                 /* The spec seems a little ambiguous here, but I think that not
1324                  * wrapping context destruction requests makes the most sense.
1325                  */
1326                 status = encode(rqstp, p, obj);
1327                 goto out;
1328         }
1329         switch (gss_cred->gc_service) {
1330                 case RPC_GSS_SVC_NONE:
1331                         status = encode(rqstp, p, obj);
1332                         break;
1333                 case RPC_GSS_SVC_INTEGRITY:
1334                         status = gss_wrap_req_integ(cred, ctx, encode,
1335                                                                 rqstp, p, obj);
1336                         break;
1337                 case RPC_GSS_SVC_PRIVACY:
1338                         status = gss_wrap_req_priv(cred, ctx, encode,
1339                                         rqstp, p, obj);
1340                         break;
1341         }
1342 out:
1343         gss_put_ctx(ctx);
1344         dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1345         return status;
1346 }
1347
1348 static inline int
1349 gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1350                 struct rpc_rqst *rqstp, __be32 **p)
1351 {
1352         struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1353         struct xdr_buf integ_buf;
1354         struct xdr_netobj mic;
1355         u32 data_offset, mic_offset;
1356         u32 integ_len;
1357         u32 maj_stat;
1358         int status = -EIO;
1359
1360         integ_len = ntohl(*(*p)++);
1361         if (integ_len & 3)
1362                 return status;
1363         data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1364         mic_offset = integ_len + data_offset;
1365         if (mic_offset > rcv_buf->len)
1366                 return status;
1367         if (ntohl(*(*p)++) != rqstp->rq_seqno)
1368                 return status;
1369
1370         if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1371                                 mic_offset - data_offset))
1372                 return status;
1373
1374         if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1375                 return status;
1376
1377         maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1378         if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1379                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1380         if (maj_stat != GSS_S_COMPLETE)
1381                 return status;
1382         return 0;
1383 }
1384
1385 static inline int
1386 gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1387                 struct rpc_rqst *rqstp, __be32 **p)
1388 {
1389         struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1390         u32 offset;
1391         u32 opaque_len;
1392         u32 maj_stat;
1393         int status = -EIO;
1394
1395         opaque_len = ntohl(*(*p)++);
1396         offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1397         if (offset + opaque_len > rcv_buf->len)
1398                 return status;
1399         /* remove padding: */
1400         rcv_buf->len = offset + opaque_len;
1401
1402         maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1403         if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1404                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1405         if (maj_stat != GSS_S_COMPLETE)
1406                 return status;
1407         if (ntohl(*(*p)++) != rqstp->rq_seqno)
1408                 return status;
1409
1410         return 0;
1411 }
1412
1413
1414 static int
1415 gss_unwrap_resp(struct rpc_task *task,
1416                 kxdrproc_t decode, void *rqstp, __be32 *p, void *obj)
1417 {
1418         struct rpc_cred *cred = task->tk_msg.rpc_cred;
1419         struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1420                         gc_base);
1421         struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1422         __be32          *savedp = p;
1423         struct kvec     *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
1424         int             savedlen = head->iov_len;
1425         int             status = -EIO;
1426
1427         if (ctx->gc_proc != RPC_GSS_PROC_DATA)
1428                 goto out_decode;
1429         switch (gss_cred->gc_service) {
1430                 case RPC_GSS_SVC_NONE:
1431                         break;
1432                 case RPC_GSS_SVC_INTEGRITY:
1433                         status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
1434                         if (status)
1435                                 goto out;
1436                         break;
1437                 case RPC_GSS_SVC_PRIVACY:
1438                         status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
1439                         if (status)
1440                                 goto out;
1441                         break;
1442         }
1443         /* take into account extra slack for integrity and privacy cases: */
1444         cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
1445                                                 + (savedlen - head->iov_len);
1446 out_decode:
1447         status = decode(rqstp, p, obj);
1448 out:
1449         gss_put_ctx(ctx);
1450         dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
1451                         status);
1452         return status;
1453 }
1454
1455 static const struct rpc_authops authgss_ops = {
1456         .owner          = THIS_MODULE,
1457         .au_flavor      = RPC_AUTH_GSS,
1458         .au_name        = "RPCSEC_GSS",
1459         .create         = gss_create,
1460         .destroy        = gss_destroy,
1461         .lookup_cred    = gss_lookup_cred,
1462         .crcreate       = gss_create_cred
1463 };
1464
1465 static const struct rpc_credops gss_credops = {
1466         .cr_name        = "AUTH_GSS",
1467         .crdestroy      = gss_destroy_cred,
1468         .cr_init        = gss_cred_init,
1469         .crbind         = rpcauth_generic_bind_cred,
1470         .crmatch        = gss_match,
1471         .crmarshal      = gss_marshal,
1472         .crrefresh      = gss_refresh,
1473         .crvalidate     = gss_validate,
1474         .crwrap_req     = gss_wrap_req,
1475         .crunwrap_resp  = gss_unwrap_resp,
1476 };
1477
1478 static const struct rpc_credops gss_nullops = {
1479         .cr_name        = "AUTH_GSS",
1480         .crdestroy      = gss_destroy_nullcred,
1481         .crbind         = rpcauth_generic_bind_cred,
1482         .crmatch        = gss_match,
1483         .crmarshal      = gss_marshal,
1484         .crrefresh      = gss_refresh_null,
1485         .crvalidate     = gss_validate,
1486         .crwrap_req     = gss_wrap_req,
1487         .crunwrap_resp  = gss_unwrap_resp,
1488 };
1489
1490 static struct rpc_pipe_ops gss_upcall_ops_v0 = {
1491         .upcall         = gss_pipe_upcall,
1492         .downcall       = gss_pipe_downcall,
1493         .destroy_msg    = gss_pipe_destroy_msg,
1494         .open_pipe      = gss_pipe_open_v0,
1495         .release_pipe   = gss_pipe_release,
1496 };
1497
1498 static struct rpc_pipe_ops gss_upcall_ops_v1 = {
1499         .upcall         = gss_pipe_upcall,
1500         .downcall       = gss_pipe_downcall,
1501         .destroy_msg    = gss_pipe_destroy_msg,
1502         .open_pipe      = gss_pipe_open_v1,
1503         .release_pipe   = gss_pipe_release,
1504 };
1505
1506 /*
1507  * Initialize RPCSEC_GSS module
1508  */
1509 static int __init init_rpcsec_gss(void)
1510 {
1511         int err = 0;
1512
1513         err = rpcauth_register(&authgss_ops);
1514         if (err)
1515                 goto out;
1516         err = gss_svc_init();
1517         if (err)
1518                 goto out_unregister;
1519         rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1520         return 0;
1521 out_unregister:
1522         rpcauth_unregister(&authgss_ops);
1523 out:
1524         return err;
1525 }
1526
/*
 * Module exit: undo init_rpcsec_gss() in reverse order. Server-side GSS
 * support is shut down first, then the client auth flavor is unregistered.
 */
static void __exit exit_rpcsec_gss(void)
{
	gss_svc_shutdown();
	rpcauth_unregister(&authgss_ops);
}
1532
/* Module metadata and entry/exit points. */
MODULE_LICENSE("GPL");
module_init(init_rpcsec_gss)
module_exit(exit_rpcsec_gss)