1 /******************************************************************************
2 *******************************************************************************
3 **
4 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
5 **  Copyright (C) 2004-2006 Red Hat, Inc.  All rights reserved.
6 **
7 **  This copyrighted material is made available to anyone wishing to use,
8 **  modify, copy, or redistribute it subject to the terms and conditions
9 **  of the GNU General Public License v.2.
10 **
11 *******************************************************************************
12 ******************************************************************************/
13
14 /*
15  * lowcomms-sctp.c
16  *
17  * This is the "low-level" comms layer.
18  *
19  * It is responsible for sending/receiving messages
20  * from other nodes in the cluster.
21  *
22  * Cluster nodes are referred to by their nodeids. Nodeids are
23  * simply 32-bit numbers to the locking module - if they need to
24  * be expanded for the cluster infrastructure then that is its
25  * responsibility. It is this layer's responsibility to resolve
26  * these into IP addresses or whatever else it needs for
27  * inter-node communication.
28  *
29  * The comms level is two kernel threads that deal mainly with
30  * the receiving of messages from other nodes and passing them
31  * up to the mid-level comms layer (which understands the
32  * message format) for execution by the locking core, and
33  * a send thread which does all the setting up of connections
34  * to remote nodes and the sending of data. Other threads are not
35  * allowed to send data themselves because that could make them
36  * block under high load. It also lets the send thread collect
37  * messages bound for one node and send them in one block.
38  *
39  * I don't see any problem with the recv thread executing the locking
40  * code on behalf of remote processes as the locking code is
41  * short, efficient and never (well, hardly ever) waits.
42  *
43  */
44
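/*
 * How the two threads map onto the code below:
 *
 *   dlm_recvd - waits on lowcomms_recv_wait and is woken by
 *               lowcomms_data_ready(); calls receive_from_sock(), which
 *               hands SCTP notifications to process_sctp_notification()
 *               and real data up to dlm_process_incoming_buffer() in
 *               midcomms.
 *
 *   dlm_sendd - waits on the socket's sk_sleep queue; drains the
 *               write_nodes list via process_output_queue() and
 *               send_to_sock(), which calls initiate_association() for a
 *               node that has no SCTP association yet.
 *               refill_write_queue() requeues every node after -EAGAIN.
 */
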
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <net/sctp/user.h>
49 #include <linux/pagemap.h>
50 #include <linux/socket.h>
51 #include <linux/idr.h>
52
53 #include "dlm_internal.h"
54 #include "lowcomms.h"
55 #include "config.h"
56 #include "midcomms.h"
57
58 static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
59 static int                      dlm_local_count;
60 static int                      dlm_local_nodeid;
61
62 /* One of these per connected node */
63
64 #define NI_INIT_PENDING 1
65 #define NI_WRITE_PENDING 2
66
67 struct nodeinfo {
68         spinlock_t              lock;
69         sctp_assoc_t            assoc_id;
70         unsigned long           flags;
71         struct list_head        write_list; /* nodes with pending writes */
72         struct list_head        writequeue; /* outgoing writequeue_entries */
73         spinlock_t              writequeue_lock;
74         int                     nodeid;
75 };
76
77 static DEFINE_IDR(nodeinfo_idr);
78 static DECLARE_RWSEM(nodeinfo_lock);
79 static int max_nodeid;
80
81 struct cbuf {
82         unsigned int base;
83         unsigned int len;
84         unsigned int mask;
85 };
86
87 /* There is just one of these now, but this struct keeps
88    the connection-specific variables together */
89
90 #define CF_READ_PENDING 1
91
92 struct connection {
93         struct socket           *sock;
94         unsigned long           flags;
95         struct page             *rx_page;
96         atomic_t                waiting_requests;
97         struct cbuf             cb;
98         int                     eagain_flag;
99 };
100
101 /* An entry waiting to be sent */
102
103 struct writequeue_entry {
104         struct list_head        list;
105         struct page             *page;
106         int                     offset;
107         int                     len;
108         int                     end;
109         int                     users;
110         struct nodeinfo         *ni;
111 };
112
113 static void cbuf_add(struct cbuf *cb, int n)
114 {
115         cb->len += n;
116 }
117
118 static int cbuf_data(struct cbuf *cb)
119 {
120         return ((cb->base + cb->len) & cb->mask);
121 }
122
123 static void cbuf_init(struct cbuf *cb, int size)
124 {
125         cb->base = cb->len = 0;
126         cb->mask = size-1;
127 }
128
129 static void cbuf_eat(struct cbuf *cb, int n)
130 {
131         cb->len  -= n;
132         cb->base += n;
133         cb->base &= cb->mask;
134 }
135
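/*
 * Example of how the cbuf helpers above cooperate (the numbers are made
 * up for illustration).  The size passed to cbuf_init() must be a power
 * of two, since only then does "size - 1" work as a wrap-around mask;
 * PAGE_CACHE_SIZE, used below, satisfies that.
 *
 *	struct cbuf cb;
 *
 *	cbuf_init(&cb, 4096);	// base = 0, len = 0, mask = 4095
 *	cbuf_add(&cb, 100);	// 100 bytes arrived at offset 0
 *	// cbuf_data(&cb) == 100: where the next recv should land
 *	cbuf_eat(&cb, 60);	// 60 bytes consumed by midcomms
 *	// now base == 60, len == 40, cbuf_data(&cb) is still 100
 */
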
136 /* List of nodes which have writes pending */
137 static LIST_HEAD(write_nodes);
138 static DEFINE_SPINLOCK(write_nodes_lock);
139
140 /* Maximum number of incoming messages to process before
141  * yielding with cond_resched()
142  */
143 #define MAX_RX_MSG_COUNT 25
144
145 /* Manage daemons */
146 static struct task_struct *recv_task;
147 static struct task_struct *send_task;
148 static DECLARE_WAIT_QUEUE_HEAD(lowcomms_recv_wait);
149
150 /* The SCTP connection */
151 static struct connection sctp_con;
152
153
154 static int nodeid_to_addr(int nodeid, struct sockaddr *retaddr)
155 {
156         struct sockaddr_storage addr;
157         int error;
158
159         if (!dlm_local_count)
160                 return -1;
161
162         error = dlm_nodeid_to_addr(nodeid, &addr);
163         if (error)
164                 return error;
165
166         if (dlm_local_addr[0]->ss_family == AF_INET) {
167                 struct sockaddr_in *in4  = (struct sockaddr_in *) &addr;
168                 struct sockaddr_in *ret4 = (struct sockaddr_in *) retaddr;
169                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
170         } else {
171                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &addr;
172                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) retaddr;
173                 memcpy(&ret6->sin6_addr, &in6->sin6_addr,
174                        sizeof(in6->sin6_addr));
175         }
176
177         return 0;
178 }
179
180 /* If alloc is 0 here we will not attempt to allocate a new
181    nodeinfo struct */
182 static struct nodeinfo *nodeid2nodeinfo(int nodeid, gfp_t alloc)
183 {
184         struct nodeinfo *ni;
185         int r;
186         int n;
187
188         down_read(&nodeinfo_lock);
189         ni = idr_find(&nodeinfo_idr, nodeid);
190         up_read(&nodeinfo_lock);
191
192         if (ni || !alloc)
193                 return ni;
194
195         down_write(&nodeinfo_lock);
196
197         ni = idr_find(&nodeinfo_idr, nodeid);
198         if (ni)
199                 goto out_up;
200
201         r = idr_pre_get(&nodeinfo_idr, alloc);
202         if (!r)
203                 goto out_up;
204
205         ni = kmalloc(sizeof(struct nodeinfo), alloc);
206         if (!ni)
207                 goto out_up;
208
209         r = idr_get_new_above(&nodeinfo_idr, ni, nodeid, &n);
210         if (r) {
211                 kfree(ni);
212                 ni = NULL;
213                 goto out_up;
214         }
215         if (n != nodeid) {
216                 idr_remove(&nodeinfo_idr, n);
217                 kfree(ni);
218                 ni = NULL;
219                 goto out_up;
220         }
221         memset(ni, 0, sizeof(struct nodeinfo));
222         spin_lock_init(&ni->lock);
223         INIT_LIST_HEAD(&ni->writequeue);
224         spin_lock_init(&ni->writequeue_lock);
225         ni->nodeid = nodeid;
226
227         if (nodeid > max_nodeid)
228                 max_nodeid = nodeid;
229 out_up:
230         up_write(&nodeinfo_lock);
231
232         return ni;
233 }
234
235 /* Don't call this too often... */
236 static struct nodeinfo *assoc2nodeinfo(sctp_assoc_t assoc)
237 {
238         int i;
239         struct nodeinfo *ni;
240
241         for (i=1; i<=max_nodeid; i++) {
242                 ni = nodeid2nodeinfo(i, 0);
243                 if (ni && ni->assoc_id == assoc)
244                         return ni;
245         }
246         return NULL;
247 }
248
249 /* Data or notification available on socket */
250 static void lowcomms_data_ready(struct sock *sk, int count_unused)
251 {
252         atomic_inc(&sctp_con.waiting_requests);
253         if (test_and_set_bit(CF_READ_PENDING, &sctp_con.flags))
254                 return;
255
256         wake_up_interruptible(&lowcomms_recv_wait);
257 }
258
259
260 /* Add the port number to an IPv4 or IPv6 sockaddr and return the address
261    length. Also pad out the struct with zeros to make comparisons meaningful */
262
263 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
264                           int *addr_len)
265 {
266         struct sockaddr_in *local4_addr;
267         struct sockaddr_in6 *local6_addr;
268
269         if (!dlm_local_count)
270                 return;
271
272         if (!port) {
273                 if (dlm_local_addr[0]->ss_family == AF_INET) {
274                         local4_addr = (struct sockaddr_in *)dlm_local_addr[0];
275                         port = be16_to_cpu(local4_addr->sin_port);
276                 } else {
277                         local6_addr = (struct sockaddr_in6 *)dlm_local_addr[0];
278                         port = be16_to_cpu(local6_addr->sin6_port);
279                 }
280         }
281
282         saddr->ss_family = dlm_local_addr[0]->ss_family;
283         if (dlm_local_addr[0]->ss_family == AF_INET) {
284                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
285                 in4_addr->sin_port = cpu_to_be16(port);
286                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
287                 memset(in4_addr+1, 0, sizeof(struct sockaddr_storage) -
288                        sizeof(struct sockaddr_in));
289                 *addr_len = sizeof(struct sockaddr_in);
290         } else {
291                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
292                 in6_addr->sin6_port = cpu_to_be16(port);
293                 memset(in6_addr+1, 0, sizeof(struct sockaddr_storage) -
294                        sizeof(struct sockaddr_in6));
295                 *addr_len = sizeof(struct sockaddr_in6);
296         }
297 }
298
299 /* Close the connection and tidy up */
300 static void close_connection(void)
301 {
302         if (sctp_con.sock) {
303                 sock_release(sctp_con.sock);
304                 sctp_con.sock = NULL;
305         }
306
307         if (sctp_con.rx_page) {
308                 __free_page(sctp_con.rx_page);
309                 sctp_con.rx_page = NULL;
310         }
311 }
312
313 /* We only send shutdown messages to nodes that are not part of the cluster */
314 static void send_shutdown(sctp_assoc_t associd)
315 {
316         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
317         struct msghdr outmessage;
318         struct cmsghdr *cmsg;
319         struct sctp_sndrcvinfo *sinfo;
320         int ret;
321
322         outmessage.msg_name = NULL;
323         outmessage.msg_namelen = 0;
324         outmessage.msg_control = outcmsg;
325         outmessage.msg_controllen = sizeof(outcmsg);
326         outmessage.msg_flags = MSG_EOR;
327
328         cmsg = CMSG_FIRSTHDR(&outmessage);
329         cmsg->cmsg_level = IPPROTO_SCTP;
330         cmsg->cmsg_type = SCTP_SNDRCV;
331         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
332         outmessage.msg_controllen = cmsg->cmsg_len;
333         sinfo = CMSG_DATA(cmsg);
334         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
335
336         sinfo->sinfo_flags |= MSG_EOF;
337         sinfo->sinfo_assoc_id = associd;
338
339         ret = kernel_sendmsg(sctp_con.sock, &outmessage, NULL, 0, 0);
340
341         if (ret != 0)
342                 log_print("send EOF to node failed: %d", ret);
343 }
344
345
346 /* INIT failed but we don't know which node...
347    restart INIT on all pending nodes */
348 static void init_failed(void)
349 {
350         int i;
351         struct nodeinfo *ni;
352
353         for (i=1; i<=max_nodeid; i++) {
354                 ni = nodeid2nodeinfo(i, 0);
355                 if (!ni)
356                         continue;
357
358                 if (test_and_clear_bit(NI_INIT_PENDING, &ni->flags)) {
359                         ni->assoc_id = 0;
360                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
361                                 spin_lock_bh(&write_nodes_lock);
362                                 list_add_tail(&ni->write_list, &write_nodes);
363                                 spin_unlock_bh(&write_nodes_lock);
364                         }
365                 }
366         }
367         wake_up_process(send_task);
368 }
369
370 /* Something happened to an association */
371 static void process_sctp_notification(struct msghdr *msg, char *buf)
372 {
373         union sctp_notification *sn = (union sctp_notification *)buf;
374
375         if (sn->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
376                 switch (sn->sn_assoc_change.sac_state) {
377
378                 case SCTP_COMM_UP:
379                 case SCTP_RESTART:
380                 {
381                         /* Check that the new node is in the lockspace */
382                         struct sctp_prim prim;
383                         mm_segment_t fs;
384                         int nodeid;
385                         int prim_len, ret;
386                         int addr_len;
387                         struct nodeinfo *ni;
388
389                         /* This seems to happen when we receive a connection
390                          * too early... or something. Anyway, it happens, but
391                          * we always seem to get a real message too; see
392                          * receive_from_sock */
393
394                         if ((int)sn->sn_assoc_change.sac_assoc_id <= 0) {
395                                 log_print("COMM_UP for invalid assoc ID %d",
396                                           (int)sn->sn_assoc_change.sac_assoc_id);
397                                 init_failed();
398                                 return;
399                         }
400                         memset(&prim, 0, sizeof(struct sctp_prim));
401                         prim_len = sizeof(struct sctp_prim);
402                         prim.ssp_assoc_id = sn->sn_assoc_change.sac_assoc_id;
403
404                         fs = get_fs();
405                         set_fs(get_ds());
406                         ret = sctp_con.sock->ops->getsockopt(sctp_con.sock,
407                                                              IPPROTO_SCTP,
408                                                              SCTP_PRIMARY_ADDR,
409                                                              (char*)&prim,
410                                                              &prim_len);
411                         set_fs(fs);
412                         if (ret < 0) {
413                                 struct nodeinfo *ni;
414
415                                 log_print("getsockopt/sctp_primary_addr on "
416                                           "new assoc %d failed : %d",
417                                           (int)sn->sn_assoc_change.sac_assoc_id,
418                                           ret);
419
420                                 /* Retry INIT later */
421                                 ni = assoc2nodeinfo(sn->sn_assoc_change.sac_assoc_id);
422                                 if (ni)
423                                         clear_bit(NI_INIT_PENDING, &ni->flags);
424                                 return;
425                         }
426                         make_sockaddr(&prim.ssp_addr, 0, &addr_len);
427                         if (dlm_addr_to_nodeid(&prim.ssp_addr, &nodeid)) {
428                                 log_print("reject connect from unknown addr");
429                                 send_shutdown(prim.ssp_assoc_id);
430                                 return;
431                         }
432
433                         ni = nodeid2nodeinfo(nodeid, GFP_KERNEL);
434                         if (!ni)
435                                 return;
436
437                         /* Save the assoc ID */
438                         ni->assoc_id = sn->sn_assoc_change.sac_assoc_id;
439
440                         log_print("got new/restarted association %d nodeid %d",
441                                   (int)sn->sn_assoc_change.sac_assoc_id, nodeid);
442
443                         /* Send any pending writes */
444                         clear_bit(NI_INIT_PENDING, &ni->flags);
445                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
446                                 spin_lock_bh(&write_nodes_lock);
447                                 list_add_tail(&ni->write_list, &write_nodes);
448                                 spin_unlock_bh(&write_nodes_lock);
449                         }
450                         wake_up_process(send_task);
451                 }
452                 break;
453
454                 case SCTP_COMM_LOST:
455                 case SCTP_SHUTDOWN_COMP:
456                 {
457                         struct nodeinfo *ni;
458
459                         ni = assoc2nodeinfo(sn->sn_assoc_change.sac_assoc_id);
460                         if (ni) {
461                                 spin_lock(&ni->lock);
462                                 ni->assoc_id = 0;
463                                 spin_unlock(&ni->lock);
464                         }
465                 }
466                 break;
467
468                 /* We don't know which INIT failed, so clear the PENDING flags
469                  * on them all.  If assoc_id is zero then the send path will
470                  * retry the INIT */
471
472                 case SCTP_CANT_STR_ASSOC:
473                 {
474                         log_print("Can't start SCTP association - retrying");
475                         init_failed();
476                 }
477                 break;
478
479                 default:
480                         log_print("unexpected SCTP assoc change id=%d state=%d",
481                                   (int)sn->sn_assoc_change.sac_assoc_id,
482                                   sn->sn_assoc_change.sac_state);
483                 }
484         }
485 }
486
487 /* Data received from remote end */
488 static int receive_from_sock(void)
489 {
490         int ret = 0;
491         struct msghdr msg;
492         struct kvec iov[2];
493         unsigned len;
494         int r;
495         struct sctp_sndrcvinfo *sinfo;
496         struct cmsghdr *cmsg;
497         struct nodeinfo *ni;
498
499         /* These two are marginally too big for stack allocation, but this
500          * function is (currently) only called by dlm_recvd so static should be
501          * OK.
502          */
503         static struct sockaddr_storage msgname;
504         static char incmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
505
506         if (sctp_con.sock == NULL)
507                 goto out;
508
509         if (sctp_con.rx_page == NULL) {
510                 /*
511                  * This doesn't need to be atomic, but I think it should
512                  * improve performance if it is.
513                  */
514                 sctp_con.rx_page = alloc_page(GFP_ATOMIC);
515                 if (sctp_con.rx_page == NULL)
516                         goto out_resched;
517                 cbuf_init(&sctp_con.cb, PAGE_CACHE_SIZE);
518         }
519
520         memset(&incmsg, 0, sizeof(incmsg));
521         memset(&msgname, 0, sizeof(msgname));
522
523         msg.msg_name = &msgname;
524         msg.msg_namelen = sizeof(msgname);
525         msg.msg_flags = 0;
526         msg.msg_control = incmsg;
527         msg.msg_controllen = sizeof(incmsg);
528         msg.msg_iovlen = 1;
529
530         /* I don't see why this circular buffer stuff is necessary for SCTP
531          * which is a packet-based protocol, but the whole thing breaks under
532          * load without it! The overhead is minimal (and is in the TCP lowcomms
533          * anyway, of course) so I'll leave it in until I can figure out what's
534          * really happening.
535          */
536
537         /*
538          * iov[0] is the bit of the circular buffer between the current end
539          * point (cb.base + cb.len) and the end of the buffer.
540          */
541         iov[0].iov_len = sctp_con.cb.base - cbuf_data(&sctp_con.cb);
542         iov[0].iov_base = page_address(sctp_con.rx_page) +
543                 cbuf_data(&sctp_con.cb);
544         iov[1].iov_len = 0;
545
546         /*
547          * iov[1] is the bit of the circular buffer between the start of the
548          * buffer and the start of the currently used section (cb.base)
549          */
550         if (cbuf_data(&sctp_con.cb) >= sctp_con.cb.base) {
551                 iov[0].iov_len = PAGE_CACHE_SIZE - cbuf_data(&sctp_con.cb);
552                 iov[1].iov_len = sctp_con.cb.base;
553                 iov[1].iov_base = page_address(sctp_con.rx_page);
554                 msg.msg_iovlen = 2;
555         }
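        /*
         * Worked example (numbers made up): with PAGE_CACHE_SIZE of 4096,
         * cb.base == 300 and cb.len == 3900, cbuf_data() is
         * (300 + 3900) & 4095 == 104, which is below cb.base, so only
         * iov[0] is used and it covers the 196 free bytes from offset 104
         * up to offset 300.  If cb.len were 100 instead, cbuf_data() would
         * be 400 and the branch above would split the free space into
         * iov[0] = [400, 4096) and iov[1] = [0, 300).
         */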
556         len = iov[0].iov_len + iov[1].iov_len;
557
558         r = ret = kernel_recvmsg(sctp_con.sock, &msg, iov, msg.msg_iovlen, len,
559                                  MSG_NOSIGNAL | MSG_DONTWAIT);
560         if (ret <= 0)
561                 goto out_close;
562
563         msg.msg_control = incmsg;
564         msg.msg_controllen = sizeof(incmsg);
565         cmsg = CMSG_FIRSTHDR(&msg);
566         sinfo = CMSG_DATA(cmsg);
567
568         if (msg.msg_flags & MSG_NOTIFICATION) {
569                 process_sctp_notification(&msg, page_address(sctp_con.rx_page));
570                 return 0;
571         }
572
573         /* Is this a new association ? */
574         ni = nodeid2nodeinfo(le32_to_cpu(sinfo->sinfo_ppid), GFP_KERNEL);
575         if (ni) {
576                 ni->assoc_id = sinfo->sinfo_assoc_id;
577                 if (test_and_clear_bit(NI_INIT_PENDING, &ni->flags)) {
578
579                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
580                                 spin_lock_bh(&write_nodes_lock);
581                                 list_add_tail(&ni->write_list, &write_nodes);
582                                 spin_unlock_bh(&write_nodes_lock);
583                         }
584                         wake_up_process(send_task);
585                 }
586         }
587
588         /* INIT sends a message with length of 1 - ignore it */
589         if (r == 1)
590                 return 0;
591
592         cbuf_add(&sctp_con.cb, ret);
593         ret = dlm_process_incoming_buffer(le32_to_cpu(sinfo->sinfo_ppid),
594                                           page_address(sctp_con.rx_page),
595                                           sctp_con.cb.base, sctp_con.cb.len,
596                                           PAGE_CACHE_SIZE);
597         if (ret < 0)
598                 goto out_close;
599         cbuf_eat(&sctp_con.cb, ret);
600
601 out:
602         ret = 0;
603         goto out_ret;
604
605 out_resched:
606         lowcomms_data_ready(sctp_con.sock->sk, 0);
607         ret = 0;
608         cond_resched();
609         goto out_ret;
610
611 out_close:
612         if (ret != -EAGAIN)
613                 log_print("error reading from sctp socket: %d", ret);
614 out_ret:
615         return ret;
616 }
617
618 /* Bind to an IP address. SCTP allows multiple addresses so it can do multi-homing */
619 static int add_bind_addr(struct sockaddr_storage *addr, int addr_len, int num)
620 {
621         mm_segment_t fs;
622         int result = 0;
623
624         fs = get_fs();
625         set_fs(get_ds());
626         if (num == 1)
627                 result = sctp_con.sock->ops->bind(sctp_con.sock,
628                                                   (struct sockaddr *) addr,
629                                                   addr_len);
630         else
631                 result = sctp_con.sock->ops->setsockopt(sctp_con.sock, SOL_SCTP,
632                                                         SCTP_SOCKOPT_BINDX_ADD,
633                                                         (char *)addr, addr_len);
634         set_fs(fs);
635
636         if (result < 0)
637                 log_print("Can't bind to port %d addr number %d",
638                           dlm_config.tcp_port, num);
639
640         return result;
641 }
642
643 static void init_local(void)
644 {
645         struct sockaddr_storage sas, *addr;
646         int i;
647
648         dlm_local_nodeid = dlm_our_nodeid();
649
650         for (i = 0; i < DLM_MAX_ADDR_COUNT - 1; i++) {
651                 if (dlm_our_addr(&sas, i))
652                         break;
653
654                 addr = kmalloc(sizeof(*addr), GFP_KERNEL);
655                 if (!addr)
656                         break;
657                 memcpy(addr, &sas, sizeof(*addr));
658                 dlm_local_addr[dlm_local_count++] = addr;
659         }
660 }
661
662 /* Initialise SCTP socket and bind to all interfaces */
663 static int init_sock(void)
664 {
665         mm_segment_t fs;
666         struct socket *sock = NULL;
667         struct sockaddr_storage localaddr;
668         struct sctp_event_subscribe subscribe;
669         int result = -EINVAL, num = 1, i, addr_len;
670
671         if (!dlm_local_count) {
672                 init_local();
673                 if (!dlm_local_count) {
674                         log_print("no local IP address has been set");
675                         goto out;
676                 }
677         }
678
679         result = sock_create_kern(dlm_local_addr[0]->ss_family, SOCK_SEQPACKET,
680                                   IPPROTO_SCTP, &sock);
681         if (result < 0) {
682                 log_print("Can't create comms socket, check SCTP is loaded");
683                 goto out;
684         }
685
686         /* Listen for events */
687         memset(&subscribe, 0, sizeof(subscribe));
688         subscribe.sctp_data_io_event = 1;
689         subscribe.sctp_association_event = 1;
690         subscribe.sctp_send_failure_event = 1;
691         subscribe.sctp_shutdown_event = 1;
692         subscribe.sctp_partial_delivery_event = 1;
693
694         fs = get_fs();
695         set_fs(get_ds());
696         result = sock->ops->setsockopt(sock, SOL_SCTP, SCTP_EVENTS,
697                                        (char *)&subscribe, sizeof(subscribe));
698         set_fs(fs);
699
700         if (result < 0) {
701                 log_print("Failed to set SCTP_EVENTS on socket: result=%d",
702                           result);
703                 goto create_delsock;
704         }
705
706         /* Init con struct */
707         sock->sk->sk_user_data = &sctp_con;
708         sctp_con.sock = sock;
709         sctp_con.sock->sk->sk_data_ready = lowcomms_data_ready;
710
711         /* Bind to all interfaces. */
712         for (i = 0; i < dlm_local_count; i++) {
713                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
714                 make_sockaddr(&localaddr, dlm_config.tcp_port, &addr_len);
715
716                 result = add_bind_addr(&localaddr, addr_len, num);
717                 if (result)
718                         goto create_delsock;
719                 ++num;
720         }
721
722         result = sock->ops->listen(sock, 5);
723         if (result < 0) {
724                 log_print("Can't set socket listening");
725                 goto create_delsock;
726         }
727
728         return 0;
729
730 create_delsock:
731         sock_release(sock);
732         sctp_con.sock = NULL;
733 out:
734         return result;
735 }
736
737
738 static struct writequeue_entry *new_writequeue_entry(gfp_t allocation)
739 {
740         struct writequeue_entry *entry;
741
742         entry = kmalloc(sizeof(struct writequeue_entry), allocation);
743         if (!entry)
744                 return NULL;
745
746         entry->page = alloc_page(allocation);
747         if (!entry->page) {
748                 kfree(entry);
749                 return NULL;
750         }
751
752         entry->offset = 0;
753         entry->len = 0;
754         entry->end = 0;
755         entry->users = 0;
756
757         return entry;
758 }
759
760 void *dlm_lowcomms_get_buffer(int nodeid, int len, gfp_t allocation, char **ppc)
761 {
762         struct writequeue_entry *e;
763         int offset = 0;
764         int users = 0;
765         struct nodeinfo *ni;
766
767         ni = nodeid2nodeinfo(nodeid, allocation);
768         if (!ni)
769                 return NULL;
770
771         spin_lock(&ni->writequeue_lock);
772         e = list_entry(ni->writequeue.prev, struct writequeue_entry, list);
773         if ((&e->list == &ni->writequeue) ||
774             (PAGE_CACHE_SIZE - e->end < len)) {
775                 e = NULL;
776         } else {
777                 offset = e->end;
778                 e->end += len;
779                 users = e->users++;
780         }
781         spin_unlock(&ni->writequeue_lock);
782
783         if (e) {
784         got_one:
785                 if (users == 0)
786                         kmap(e->page);
787                 *ppc = page_address(e->page) + offset;
788                 return e;
789         }
790
791         e = new_writequeue_entry(allocation);
792         if (e) {
793                 spin_lock(&ni->writequeue_lock);
794                 offset = e->end;
795                 e->end += len;
796                 e->ni = ni;
797                 users = e->users++;
798                 list_add_tail(&e->list, &ni->writequeue);
799                 spin_unlock(&ni->writequeue_lock);
800                 goto got_one;
801         }
802         return NULL;
803 }
804
805 void dlm_lowcomms_commit_buffer(void *arg)
806 {
807         struct writequeue_entry *e = (struct writequeue_entry *) arg;
808         int users;
809         struct nodeinfo *ni = e->ni;
810
811         spin_lock(&ni->writequeue_lock);
812         users = --e->users;
813         if (users)
814                 goto out;
815         e->len = e->end - e->offset;
816         kunmap(e->page);
817         spin_unlock(&ni->writequeue_lock);
818
819         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
820                 spin_lock_bh(&write_nodes_lock);
821                 list_add_tail(&ni->write_list, &write_nodes);
822                 spin_unlock_bh(&write_nodes_lock);
823                 wake_up_process(send_task);
824         }
825         return;
826
827 out:
828         spin_unlock(&ni->writequeue_lock);
829         return;
830 }
831
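/*
 * Sketch of the intended caller pattern for dlm_lowcomms_get_buffer() and
 * dlm_lowcomms_commit_buffer() above - roughly what midcomms does when it
 * builds an outgoing message.  The variable names and the error handling
 * shown here are illustrative, not taken from the real caller.
 *
 *	char *p;
 *	void *handle;
 *
 *	handle = dlm_lowcomms_get_buffer(nodeid, msglen, GFP_KERNEL, &p);
 *	if (!handle)
 *		return -ENOBUFS;
 *	memcpy(p, msg, msglen);			// fill the reserved space
 *	dlm_lowcomms_commit_buffer(handle);	// queue it; dlm_sendd sends it
 */
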
832 static void free_entry(struct writequeue_entry *e)
833 {
834         __free_page(e->page);
835         kfree(e);
836 }
837
838 /* Initiate an SCTP association. In theory we could just use sendmsg() on
839    the first IP address and it should work, but this allows us to set up the
840    association before sending any valuable data that we can't afford to lose.
841    It also keeps the send path clean as it can now always use the association ID */
842 static void initiate_association(int nodeid)
843 {
844         struct sockaddr_storage rem_addr;
845         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
846         struct msghdr outmessage;
847         struct cmsghdr *cmsg;
848         struct sctp_sndrcvinfo *sinfo;
849         int ret;
850         int addrlen;
851         char buf[1];
852         struct kvec iov[1];
853         struct nodeinfo *ni;
854
855         log_print("Initiating association with node %d", nodeid);
856
857         ni = nodeid2nodeinfo(nodeid, GFP_KERNEL);
858         if (!ni)
859                 return;
860
861         if (nodeid_to_addr(nodeid, (struct sockaddr *)&rem_addr)) {
862                 log_print("no address for nodeid %d", nodeid);
863                 return;
864         }
865
866         make_sockaddr(&rem_addr, dlm_config.tcp_port, &addrlen);
867
868         outmessage.msg_name = &rem_addr;
869         outmessage.msg_namelen = addrlen;
870         outmessage.msg_control = outcmsg;
871         outmessage.msg_controllen = sizeof(outcmsg);
872         outmessage.msg_flags = MSG_EOR;
873
874         iov[0].iov_base = buf;
875         iov[0].iov_len = 1;
876
877         /* Real INIT messages seem to cause trouble. Just send a 1 byte message
878            we can afford to lose */
879         cmsg = CMSG_FIRSTHDR(&outmessage);
880         cmsg->cmsg_level = IPPROTO_SCTP;
881         cmsg->cmsg_type = SCTP_SNDRCV;
882         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
883         sinfo = CMSG_DATA(cmsg);
884         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
885         sinfo->sinfo_ppid = cpu_to_le32(dlm_local_nodeid);
886
887         outmessage.msg_controllen = cmsg->cmsg_len;
888         ret = kernel_sendmsg(sctp_con.sock, &outmessage, iov, 1, 1);
889         if (ret < 0) {
890                 log_print("send INIT to node failed: %d", ret);
891                 /* Try again later */
892                 clear_bit(NI_INIT_PENDING, &ni->flags);
893         }
894 }
895
896 /* Send a message */
897 static void send_to_sock(struct nodeinfo *ni)
898 {
899         int ret = 0;
900         struct writequeue_entry *e;
901         int len, offset;
902         struct msghdr outmsg;
903         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
904         struct cmsghdr *cmsg;
905         struct sctp_sndrcvinfo *sinfo;
906         struct kvec iov;
907
908         /* See if we need to init an association before we start
909            sending precious messages */
910         spin_lock(&ni->lock);
911         if (!ni->assoc_id && !test_and_set_bit(NI_INIT_PENDING, &ni->flags)) {
912                 spin_unlock(&ni->lock);
913                 initiate_association(ni->nodeid);
914                 return;
915         }
916         spin_unlock(&ni->lock);
917
918         outmsg.msg_name = NULL; /* We use assoc_id */
919         outmsg.msg_namelen = 0;
920         outmsg.msg_control = outcmsg;
921         outmsg.msg_controllen = sizeof(outcmsg);
922         outmsg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR;
923
924         cmsg = CMSG_FIRSTHDR(&outmsg);
925         cmsg->cmsg_level = IPPROTO_SCTP;
926         cmsg->cmsg_type = SCTP_SNDRCV;
927         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
928         sinfo = CMSG_DATA(cmsg);
929         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
930         sinfo->sinfo_ppid = cpu_to_le32(dlm_local_nodeid);
931         sinfo->sinfo_assoc_id = ni->assoc_id;
932         outmsg.msg_controllen = cmsg->cmsg_len;
933
934         spin_lock(&ni->writequeue_lock);
935         for (;;) {
936                 if (list_empty(&ni->writequeue))
937                         break;
938                 e = list_entry(ni->writequeue.next, struct writequeue_entry,
939                                list);
940                 len = e->len;
941                 offset = e->offset;
942                 BUG_ON(len == 0 && e->users == 0);
943                 spin_unlock(&ni->writequeue_lock);
944                 kmap(e->page);
945
946                 ret = 0;
947                 if (len) {
948                         iov.iov_base = page_address(e->page)+offset;
949                         iov.iov_len = len;
950
951                         ret = kernel_sendmsg(sctp_con.sock, &outmsg, &iov, 1,
952                                              len);
953                         if (ret == -EAGAIN) {
954                                 sctp_con.eagain_flag = 1;
955                                 goto out;
956                         } else if (ret < 0)
957                                 goto send_error;
958                 } else {
959                         /* Don't starve people filling buffers */
960                         cond_resched();
961                 }
962
963                 spin_lock(&ni->writequeue_lock);
964                 e->offset += ret;
965                 e->len -= ret;
966
967                 if (e->len == 0 && e->users == 0) {
968                         list_del(&e->list);
969                         kunmap(e->page);
970                         free_entry(e);
971                         continue;
972                 }
973         }
974         spin_unlock(&ni->writequeue_lock);
975 out:
976         return;
977
978 send_error:
979         log_print("Error sending to node %d %d", ni->nodeid, ret);
980         spin_lock(&ni->lock);
981         if (!test_and_set_bit(NI_INIT_PENDING, &ni->flags)) {
982                 ni->assoc_id = 0;
983                 spin_unlock(&ni->lock);
984                 initiate_association(ni->nodeid);
985         } else
986                 spin_unlock(&ni->lock);
987
988         return;
989 }
990
991 /* Try to send any messages that are pending */
992 static void process_output_queue(void)
993 {
994         struct list_head *list;
995         struct list_head *temp;
996
997         spin_lock_bh(&write_nodes_lock);
998         list_for_each_safe(list, temp, &write_nodes) {
999                 struct nodeinfo *ni =
1000                         list_entry(list, struct nodeinfo, write_list);
1001                 clear_bit(NI_WRITE_PENDING, &ni->flags);
1002                 list_del(&ni->write_list);
1003
1004                 spin_unlock_bh(&write_nodes_lock);
1005
1006                 send_to_sock(ni);
1007                 spin_lock_bh(&write_nodes_lock);
1008         }
1009         spin_unlock_bh(&write_nodes_lock);
1010 }
1011
1012 /* Called after we've had -EAGAIN and been woken up */
1013 static void refill_write_queue(void)
1014 {
1015         int i;
1016
1017         for (i=1; i<=max_nodeid; i++) {
1018                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1019
1020                 if (ni) {
1021                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
1022                                 spin_lock_bh(&write_nodes_lock);
1023                                 list_add_tail(&ni->write_list, &write_nodes);
1024                                 spin_unlock_bh(&write_nodes_lock);
1025                         }
1026                 }
1027         }
1028 }
1029
1030 static void clean_one_writequeue(struct nodeinfo *ni)
1031 {
1032         struct list_head *list;
1033         struct list_head *temp;
1034
1035         spin_lock(&ni->writequeue_lock);
1036         list_for_each_safe(list, temp, &ni->writequeue) {
1037                 struct writequeue_entry *e =
1038                         list_entry(list, struct writequeue_entry, list);
1039                 list_del(&e->list);
1040                 free_entry(e);
1041         }
1042         spin_unlock(&ni->writequeue_lock);
1043 }
1044
1045 static void clean_writequeues(void)
1046 {
1047         int i;
1048
1049         for (i=1; i<=max_nodeid; i++) {
1050                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1051                 if (ni)
1052                         clean_one_writequeue(ni);
1053         }
1054 }
1055
1056
1057 static void dealloc_nodeinfo(void)
1058 {
1059         int i;
1060
1061         for (i=1; i<=max_nodeid; i++) {
1062                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1063                 if (ni) {
1064                         idr_remove(&nodeinfo_idr, i);
1065                         kfree(ni);
1066                 }
1067         }
1068 }
1069
1070 int dlm_lowcomms_close(int nodeid)
1071 {
1072         struct nodeinfo *ni;
1073
1074         ni = nodeid2nodeinfo(nodeid, 0);
1075         if (!ni)
1076                 return -1;
1077
1078         spin_lock(&ni->lock);
1079         if (ni->assoc_id) {
1080                 ni->assoc_id = 0;
1081                 /* Don't send shutdown here, sctp will just queue it
1082                    till the node comes back up! */
1083         }
1084         spin_unlock(&ni->lock);
1085
1086         clean_one_writequeue(ni);
1087         clear_bit(NI_INIT_PENDING, &ni->flags);
1088         return 0;
1089 }
1090
1091 static int write_list_empty(void)
1092 {
1093         int status;
1094
1095         spin_lock_bh(&write_nodes_lock);
1096         status = list_empty(&write_nodes);
1097         spin_unlock_bh(&write_nodes_lock);
1098
1099         return status;
1100 }
1101
1102 static int dlm_recvd(void *data)
1103 {
1104         DECLARE_WAITQUEUE(wait, current);
1105
1106         while (!kthread_should_stop()) {
1107                 int count = 0;
1108
1109                 set_current_state(TASK_INTERRUPTIBLE);
1110                 add_wait_queue(&lowcomms_recv_wait, &wait);
1111                 if (!test_bit(CF_READ_PENDING, &sctp_con.flags))
1112                         cond_resched();
1113                 remove_wait_queue(&lowcomms_recv_wait, &wait);
1114                 set_current_state(TASK_RUNNING);
1115
1116                 if (test_and_clear_bit(CF_READ_PENDING, &sctp_con.flags)) {
1117                         int ret;
1118
1119                         do {
1120                                 ret = receive_from_sock();
1121
1122                                 /* Don't starve out everyone else */
1123                                 if (++count >= MAX_RX_MSG_COUNT) {
1124                                         cond_resched();
1125                                         count = 0;
1126                                 }
1127                         } while (!kthread_should_stop() && ret >=0);
1128                 }
1129                 cond_resched();
1130         }
1131
1132         return 0;
1133 }
1134
1135 static int dlm_sendd(void *data)
1136 {
1137         DECLARE_WAITQUEUE(wait, current);
1138
1139         add_wait_queue(sctp_con.sock->sk->sk_sleep, &wait);
1140
1141         while (!kthread_should_stop()) {
1142                 set_current_state(TASK_INTERRUPTIBLE);
1143                 if (write_list_empty())
1144                         cond_resched();
1145                 set_current_state(TASK_RUNNING);
1146
1147                 if (sctp_con.eagain_flag) {
1148                         sctp_con.eagain_flag = 0;
1149                         refill_write_queue();
1150                 }
1151                 process_output_queue();
1152         }
1153
1154         remove_wait_queue(sctp_con.sock->sk->sk_sleep, &wait);
1155
1156         return 0;
1157 }
1158
1159 static void daemons_stop(void)
1160 {
1161         kthread_stop(recv_task);
1162         kthread_stop(send_task);
1163 }
1164
1165 static int daemons_start(void)
1166 {
1167         struct task_struct *p;
1168         int error;
1169
1170         p = kthread_run(dlm_recvd, NULL, "dlm_recvd");
1171         if (IS_ERR(p)) {
1172                 error = PTR_ERR(p);
1173                 log_print("can't start dlm_recvd %d", error);
1174                 return error;
1175         }
1176         recv_task = p;
1177
1178         p = kthread_run(dlm_sendd, NULL, "dlm_sendd");
1179         if (IS_ERR(p)) {
1180                 error = PTR_ERR(p);
1181                 log_print("can't start dlm_sendd %d", error);
1182                 kthread_stop(recv_task);
1183                 return error;
1184         }
1185         send_task = p;
1186
1187         return 0;
1188 }
1189
1190 /*
1191  * This is quite likely to sleep...
1192  */
1193 int dlm_lowcomms_start(void)
1194 {
1195         int error;
1196
1197         error = init_sock();
1198         if (error)
1199                 goto fail_sock;
1200         error = daemons_start();
1201         if (error)
1202                 goto fail_sock;
1203         return 0;
1204
1205 fail_sock:
1206         close_connection();
1207         return error;
1208 }
1209
1210 void dlm_lowcomms_stop(void)
1211 {
1212         int i;
1213
1214         sctp_con.flags = 0x7;
1215         daemons_stop();
1216         clean_writequeues();
1217         close_connection();
1218         dealloc_nodeinfo();
1219         max_nodeid = 0;
1220
1221         for (i = 0; i < dlm_local_count; i++)
1222                 kfree(dlm_local_addr[i]);
1223
1224         dlm_local_count = 0;
1225         dlm_local_nodeid = 0;
1226 }
1227
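/*
 * Sketch of the lifecycle of this layer as seen from the rest of the DLM
 * (the error handling shown is illustrative):
 *
 *	error = dlm_lowcomms_start();	// create + bind socket, start daemons
 *	if (error)
 *		return error;
 *	...
 *	dlm_lowcomms_close(nodeid);	// forget a node that has left
 *	...
 *	dlm_lowcomms_stop();		// stop daemons, free all state
 */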