Merge git://git.kernel.org/pub/scm/linux/kernel/git/steve/gfs2-2.6
[linux-2.6] / fs / dlm / lowcomms.c
1 /******************************************************************************
2 *******************************************************************************
3 **
4 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
5 **  Copyright (C) 2004-2005 Red Hat, Inc.  All rights reserved.
6 **
7 **  This copyrighted material is made available to anyone wishing to use,
8 **  modify, copy, or redistribute it subject to the terms and conditions
9 **  of the GNU General Public License v.2.
10 **
11 *******************************************************************************
12 ******************************************************************************/
13
14 /*
15  * lowcomms.c
16  *
17  * This is the "low-level" comms layer.
18  *
19  * It is responsible for sending/receiving messages
20  * from other nodes in the cluster.
21  *
 * Cluster nodes are referred to by their nodeids. nodeids are
 * simply 32 bit numbers to the locking module - if they need to
 * be expanded for the cluster infrastructure then that is its
 * responsibility. It is this layer's
 * responsibility to resolve these into IP addresses or
 * whatever it needs for inter-node communication.
28  *
 * The comms layer consists of two kernel threads: a receive thread,
 * which deals mainly with receiving messages from other nodes and
 * passing them up to the mid-level comms layer (which understands the
 * message format) for execution by the locking core, and a send thread,
 * which does all the setting up of connections to remote nodes and the
 * sending of data. Other threads are not allowed to send their own data
 * because that may cause them to wait in times of high load. This way the
 * send thread can also collect together messages bound for one node and
 * send them in one block.
38  *
39  * I don't see any problem with the recv thread executing the locking
40  * code on behalf of remote processes as the locking code is
41  * short, efficient and never (well, hardly ever) waits.
42  *
43  */
44
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <net/sctp/user.h>
49 #include <linux/pagemap.h>
50 #include <linux/socket.h>
51 #include <linux/idr.h>
52
53 #include "dlm_internal.h"
54 #include "lowcomms.h"
55 #include "config.h"
56 #include "midcomms.h"
57
/* Our local addresses, filled in by init_local() from dlm_our_addr().
 * Entry [0] also decides the address family used throughout this file. */
static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
/* Number of valid entries in dlm_local_addr[]. */
static int                      dlm_local_count;
/* Our own nodeid (from dlm_our_nodeid()); sent as sinfo_ppid on INIT. */
static int                      dlm_local_nodeid;

/* One of these per connected node */

/* Bit numbers in nodeinfo.flags.
 * NI_INIT_PENDING: an INIT is outstanding for the node (presumably set by the
 *   send path, which is not visible here); cleared when the association comes
 *   up or when the INIT has to be retried.
 * NI_WRITE_PENDING: the node has been queued on write_nodes. */
#define NI_INIT_PENDING 1
#define NI_WRITE_PENDING 2

struct nodeinfo {
        spinlock_t              lock;       /* taken around assoc_id updates in process_sctp_notification() */
        sctp_assoc_t            assoc_id;   /* SCTP association with this node; 0 when there is none */
        unsigned long           flags;      /* NI_* bits */
        struct list_head        write_list; /* nodes with pending writes */
        struct list_head        writequeue; /* outgoing writequeue_entries */
        spinlock_t              writequeue_lock; /* protects writequeue and its entries' end/users/len */
        int                     nodeid;
};

/* nodeid -> struct nodeinfo; lookups and insertions hold nodeinfo_lock. */
static DEFINE_IDR(nodeinfo_idr);
/* NOTE(review): not initialised here; presumably init_rwsem() runs elsewhere. */
static struct rw_semaphore      nodeinfo_lock;
/* Highest nodeid given a nodeinfo; bounds the scans in assoc2nodeinfo() and init_failed(). */
static int                      max_nodeid;
80
/* Circular-buffer bookkeeping for the receive page.  The size must be a
 * power of two because offsets are wrapped with "& mask" (see CBUF_INIT and
 * CBUF_EAT). */
struct cbuf {
        unsigned                base;   /* offset of the first unconsumed byte */
        unsigned                len;    /* number of bytes currently held */
        unsigned                mask;   /* buffer size - 1 */
};

/* Just the one of these, now. But this struct keeps
   the connection-specific variables together */

/* Bit number in connection.flags: a receive wakeup is already pending. */
#define CF_READ_PENDING 1

struct connection {
        struct socket          *sock;             /* the SCTP socket; NULL when closed */
        unsigned long           flags;            /* CF_* bits */
        struct page            *rx_page;          /* receive buffer, managed through cb */
        atomic_t                waiting_requests; /* incremented by lowcomms_data_ready() */
        struct cbuf             cb;               /* circular-buffer state for rx_page */
        int                     eagain_flag;      /* NOTE(review): use not visible in this chunk */
};

/* An entry waiting to be sent: one page of outgoing data for one node. */

struct writequeue_entry {
        struct list_head        list;   /* link in nodeinfo.writequeue */
        struct page            *page;   /* the data page */
        int                     offset; /* start of data not yet sent (presumably advanced by the send path) */
        int                     len;    /* committed bytes, set to end - offset by dlm_lowcomms_commit_buffer() */
        int                     end;    /* end of the space reserved by dlm_lowcomms_get_buffer() */
        int                     users;  /* reservations handed out but not yet committed */
        struct nodeinfo        *ni;     /* node this entry is queued for */
};
112
/*
 * Circular buffer helpers for struct cbuf.  The buffer size passed to
 * CBUF_INIT must be a power of two: offsets wrap with "& mask".
 *
 * CBUF_ADD(cb, n)     - account for n bytes appended at CBUF_DATA(cb)
 * CBUF_EMPTY(cb)      - true when no bytes are held
 * CBUF_MAY_ADD(cb, n) - true if n more bytes fit without filling the buffer
 * CBUF_DATA(cb)       - offset at which the next byte will be appended
 * CBUF_INIT(cb, size) - reset to empty with the given (power-of-two) size
 * CBUF_EAT(cb, n)     - consume n bytes from the front
 */
#define CBUF_ADD(cb, n) do { (cb)->len += (n); } while(0)
#define CBUF_EMPTY(cb) ((cb)->len == 0)
#define CBUF_MAY_ADD(cb, n) (((cb)->len + (n)) < ((cb)->mask + 1))
#define CBUF_DATA(cb) (((cb)->base + (cb)->len) & (cb)->mask)

#define CBUF_INIT(cb, size) \
do { \
        (cb)->base = (cb)->len = 0; \
        (cb)->mask = ((size)-1); \
} while(0)

/* n is evaluated exactly once, before either field changes, so arguments
 * with side effects or that read the cbuf itself (e.g. (cb)->len) work. */
#define CBUF_EAT(cb, n) \
do { \
        unsigned cbuf_eat_n = (n); \
        (cb)->len  -= cbuf_eat_n; \
        (cb)->base = ((cb)->base + cbuf_eat_n) & (cb)->mask; \
} while(0)
130
131
/* List of nodes which have writes pending, linked through
 * nodeinfo.write_list.  A node is added when NI_WRITE_PENDING is first set. */
static struct list_head write_nodes;
/* Protects write_nodes; always taken with spin_lock_bh() in this file. */
static spinlock_t write_nodes_lock;

/* Maximum number of incoming messages to process before
 * doing a schedule()
 */
#define MAX_RX_MSG_COUNT 25

/* Manage daemons */
static struct task_struct *recv_task;
/* Woken with wake_up_process() whenever there may be writes to send. */
static struct task_struct *send_task;
/* Woken by lowcomms_data_ready() when the socket becomes readable. */
static wait_queue_head_t lowcomms_recv_wait;
/* Non-zero while buffers may be handed out and committed; checked by
 * dlm_lowcomms_get_buffer() and dlm_lowcomms_commit_buffer(). */
static atomic_t accepting;

/* The SCTP connection */
static struct connection sctp_con;
149
150
151 static int nodeid_to_addr(int nodeid, struct sockaddr *retaddr)
152 {
153         struct sockaddr_storage addr;
154         int error;
155
156         if (!dlm_local_count)
157                 return -1;
158
159         error = dlm_nodeid_to_addr(nodeid, &addr);
160         if (error)
161                 return error;
162
163         if (dlm_local_addr[0]->ss_family == AF_INET) {
164                 struct sockaddr_in *in4  = (struct sockaddr_in *) &addr;
165                 struct sockaddr_in *ret4 = (struct sockaddr_in *) retaddr;
166                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
167         } else {
168                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &addr;
169                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) retaddr;
170                 memcpy(&ret6->sin6_addr, &in6->sin6_addr,
171                        sizeof(in6->sin6_addr));
172         }
173
174         return 0;
175 }
176
/*
 * Look up the nodeinfo for @nodeid in nodeinfo_idr.
 *
 * If there is none and @alloc is non-zero, create one.  @alloc doubles as
 * the allocation flags handed to idr_pre_get() and kmalloc().  A new entry
 * is zeroed, gets its locks and writequeue initialised, and raises
 * max_nodeid if needed.
 *
 * Returns the nodeinfo, or NULL if it does not exist (with @alloc == 0) or
 * could not be created.
 */
static struct nodeinfo *nodeid2nodeinfo(int nodeid, int alloc)
{
        struct nodeinfo *ni;
        int r;
        int n;

        /* Fast path: plain lookup under the read lock. */
        down_read(&nodeinfo_lock);
        ni = idr_find(&nodeinfo_idr, nodeid);
        up_read(&nodeinfo_lock);

        if (!ni && alloc) {
                down_write(&nodeinfo_lock);

                /* Look again: another caller may have inserted it between
                 * up_read() and down_write(). */
                ni = idr_find(&nodeinfo_idr, nodeid);
                if (ni)
                        goto out_up;

                r = idr_pre_get(&nodeinfo_idr, alloc);
                if (!r)
                        goto out_up;

                ni = kmalloc(sizeof(struct nodeinfo), alloc);
                if (!ni)
                        goto out_up;

                r = idr_get_new_above(&nodeinfo_idr, ni, nodeid, &n);
                if (r) {
                        kfree(ni);
                        ni = NULL;
                        goto out_up;
                }
                /* Presumably idr_get_new_above() picks the lowest free id
                 * >= nodeid; anything but nodeid itself is useless to us,
                 * so undo the insertion. */
                if (n != nodeid) {
                        idr_remove(&nodeinfo_idr, n);
                        kfree(ni);
                        ni = NULL;
                        goto out_up;
                }
                /* The entry is already in the idr, but nodeinfo_lock is
                 * still held for writing, so lookups in this file (which
                 * take the read lock) cannot see it half-initialised. */
                memset(ni, 0, sizeof(struct nodeinfo));
                spin_lock_init(&ni->lock);
                INIT_LIST_HEAD(&ni->writequeue);
                spin_lock_init(&ni->writequeue_lock);
                ni->nodeid = nodeid;

                if (nodeid > max_nodeid)
                        max_nodeid = nodeid;
        out_up:
                up_write(&nodeinfo_lock);
        }

        return ni;
}
228
229 /* Don't call this too often... */
230 static struct nodeinfo *assoc2nodeinfo(sctp_assoc_t assoc)
231 {
232         int i;
233         struct nodeinfo *ni;
234
235         for (i=1; i<=max_nodeid; i++) {
236                 ni = nodeid2nodeinfo(i, 0);
237                 if (ni && ni->assoc_id == assoc)
238                         return ni;
239         }
240         return NULL;
241 }
242
243 /* Data or notification available on socket */
244 static void lowcomms_data_ready(struct sock *sk, int count_unused)
245 {
246         atomic_inc(&sctp_con.waiting_requests);
247         if (test_and_set_bit(CF_READ_PENDING, &sctp_con.flags))
248                 return;
249
250         wake_up_interruptible(&lowcomms_recv_wait);
251 }
252
253
254 /* Add the port number to an IP6 or 4 sockaddr and return the address length.
255    Also padd out the struct with zeros to make comparisons meaningful */
256
257 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
258                           int *addr_len)
259 {
260         struct sockaddr_in *local4_addr;
261         struct sockaddr_in6 *local6_addr;
262
263         if (!dlm_local_count)
264                 return;
265
266         if (!port) {
267                 if (dlm_local_addr[0]->ss_family == AF_INET) {
268                         local4_addr = (struct sockaddr_in *)dlm_local_addr[0];
269                         port = be16_to_cpu(local4_addr->sin_port);
270                 } else {
271                         local6_addr = (struct sockaddr_in6 *)dlm_local_addr[0];
272                         port = be16_to_cpu(local6_addr->sin6_port);
273                 }
274         }
275
276         saddr->ss_family = dlm_local_addr[0]->ss_family;
277         if (dlm_local_addr[0]->ss_family == AF_INET) {
278                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
279                 in4_addr->sin_port = cpu_to_be16(port);
280                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
281                 memset(in4_addr+1, 0, sizeof(struct sockaddr_storage) -
282                                       sizeof(struct sockaddr_in));
283                 *addr_len = sizeof(struct sockaddr_in);
284         } else {
285                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
286                 in6_addr->sin6_port = cpu_to_be16(port);
287                 memset(in6_addr+1, 0, sizeof(struct sockaddr_storage) -
288                                       sizeof(struct sockaddr_in6));
289                 *addr_len = sizeof(struct sockaddr_in6);
290         }
291 }
292
293 /* Close the connection and tidy up */
294 static void close_connection(void)
295 {
296         if (sctp_con.sock) {
297                 sock_release(sctp_con.sock);
298                 sctp_con.sock = NULL;
299         }
300
301         if (sctp_con.rx_page) {
302                 __free_page(sctp_con.rx_page);
303                 sctp_con.rx_page = NULL;
304         }
305 }
306
307 /* We only send shutdown messages to nodes that are not part of the cluster */
308 static void send_shutdown(sctp_assoc_t associd)
309 {
310         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
311         struct msghdr outmessage;
312         struct cmsghdr *cmsg;
313         struct sctp_sndrcvinfo *sinfo;
314         int ret;
315
316         outmessage.msg_name = NULL;
317         outmessage.msg_namelen = 0;
318         outmessage.msg_control = outcmsg;
319         outmessage.msg_controllen = sizeof(outcmsg);
320         outmessage.msg_flags = MSG_EOR;
321
322         cmsg = CMSG_FIRSTHDR(&outmessage);
323         cmsg->cmsg_level = IPPROTO_SCTP;
324         cmsg->cmsg_type = SCTP_SNDRCV;
325         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
326         outmessage.msg_controllen = cmsg->cmsg_len;
327         sinfo = (struct sctp_sndrcvinfo *)CMSG_DATA(cmsg);
328         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
329
330         sinfo->sinfo_flags |= MSG_EOF;
331         sinfo->sinfo_assoc_id = associd;
332
333         ret = kernel_sendmsg(sctp_con.sock, &outmessage, NULL, 0, 0);
334
335         if (ret != 0)
336                 log_print("send EOF to node failed: %d", ret);
337 }
338
339
340 /* INIT failed but we don't know which node...
341    restart INIT on all pending nodes */
342 static void init_failed(void)
343 {
344         int i;
345         struct nodeinfo *ni;
346
347         for (i=1; i<=max_nodeid; i++) {
348                 ni = nodeid2nodeinfo(i, 0);
349                 if (!ni)
350                         continue;
351
352                 if (test_and_clear_bit(NI_INIT_PENDING, &ni->flags)) {
353                         ni->assoc_id = 0;
354                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
355                                 spin_lock_bh(&write_nodes_lock);
356                                 list_add_tail(&ni->write_list, &write_nodes);
357                                 spin_unlock_bh(&write_nodes_lock);
358                         }
359                 }
360         }
361         wake_up_process(send_task);
362 }
363
364 /* Something happened to an association */
365 static void process_sctp_notification(struct msghdr *msg, char *buf)
366 {
367         union sctp_notification *sn = (union sctp_notification *)buf;
368
369         if (sn->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
370                 switch (sn->sn_assoc_change.sac_state) {
371
372                 case SCTP_COMM_UP:
373                 case SCTP_RESTART:
374                 {
375                         /* Check that the new node is in the lockspace */
376                         struct sctp_prim prim;
377                         mm_segment_t fs;
378                         int nodeid;
379                         int prim_len, ret;
380                         int addr_len;
381                         struct nodeinfo *ni;
382
383                         /* This seems to happen when we received a connection
384                          * too early... or something...  anyway, it happens but
385                          * we always seem to get a real message too, see
386                          * receive_from_sock */
387
388                         if ((int)sn->sn_assoc_change.sac_assoc_id <= 0) {
389                                 log_print("COMM_UP for invalid assoc ID %d",
390                                          (int)sn->sn_assoc_change.sac_assoc_id);
391                                 init_failed();
392                                 return;
393                         }
394                         memset(&prim, 0, sizeof(struct sctp_prim));
395                         prim_len = sizeof(struct sctp_prim);
396                         prim.ssp_assoc_id = sn->sn_assoc_change.sac_assoc_id;
397
398                         fs = get_fs();
399                         set_fs(get_ds());
400                         ret = sctp_con.sock->ops->getsockopt(sctp_con.sock,
401                                                 IPPROTO_SCTP, SCTP_PRIMARY_ADDR,
402                                                 (char*)&prim, &prim_len);
403                         set_fs(fs);
404                         if (ret < 0) {
405                                 struct nodeinfo *ni;
406
407                                 log_print("getsockopt/sctp_primary_addr on "
408                                           "new assoc %d failed : %d",
409                                     (int)sn->sn_assoc_change.sac_assoc_id, ret);
410
411                                 /* Retry INIT later */
412                                 ni = assoc2nodeinfo(sn->sn_assoc_change.sac_assoc_id);
413                                 if (ni)
414                                         clear_bit(NI_INIT_PENDING, &ni->flags);
415                                 return;
416                         }
417                         make_sockaddr(&prim.ssp_addr, 0, &addr_len);
418                         if (dlm_addr_to_nodeid(&prim.ssp_addr, &nodeid)) {
419                                 log_print("reject connect from unknown addr");
420                                 send_shutdown(prim.ssp_assoc_id);
421                                 return;
422                         }
423
424                         ni = nodeid2nodeinfo(nodeid, GFP_KERNEL);
425                         if (!ni)
426                                 return;
427
428                         /* Save the assoc ID */
429                         spin_lock(&ni->lock);
430                         ni->assoc_id = sn->sn_assoc_change.sac_assoc_id;
431                         spin_unlock(&ni->lock);
432
433                         log_print("got new/restarted association %d nodeid %d",
434                                (int)sn->sn_assoc_change.sac_assoc_id, nodeid);
435
436                         /* Send any pending writes */
437                         clear_bit(NI_INIT_PENDING, &ni->flags);
438                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
439                                 spin_lock_bh(&write_nodes_lock);
440                                 list_add_tail(&ni->write_list, &write_nodes);
441                                 spin_unlock_bh(&write_nodes_lock);
442                         }
443                         wake_up_process(send_task);
444                 }
445                 break;
446
447                 case SCTP_COMM_LOST:
448                 case SCTP_SHUTDOWN_COMP:
449                 {
450                         struct nodeinfo *ni;
451
452                         ni = assoc2nodeinfo(sn->sn_assoc_change.sac_assoc_id);
453                         if (ni) {
454                                 spin_lock(&ni->lock);
455                                 ni->assoc_id = 0;
456                                 spin_unlock(&ni->lock);
457                         }
458                 }
459                 break;
460
461                 /* We don't know which INIT failed, so clear the PENDING flags
462                  * on them all.  if assoc_id is zero then it will then try
463                  * again */
464
465                 case SCTP_CANT_STR_ASSOC:
466                 {
467                         log_print("Can't start SCTP association - retrying");
468                         init_failed();
469                 }
470                 break;
471
472                 default:
473                         log_print("unexpected SCTP assoc change id=%d state=%d",
474                                   (int)sn->sn_assoc_change.sac_assoc_id,
475                                   sn->sn_assoc_change.sac_state);
476                 }
477         }
478 }
479
480 /* Data received from remote end */
481 static int receive_from_sock(void)
482 {
483         int ret = 0;
484         struct msghdr msg;
485         struct kvec iov[2];
486         unsigned len;
487         int r;
488         struct sctp_sndrcvinfo *sinfo;
489         struct cmsghdr *cmsg;
490         struct nodeinfo *ni;
491
492         /* These two are marginally too big for stack allocation, but this
493          * function is (currently) only called by dlm_recvd so static should be
494          * OK.
495          */
496         static struct sockaddr_storage msgname;
497         static char incmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
498
499         if (sctp_con.sock == NULL)
500                 goto out;
501
502         if (sctp_con.rx_page == NULL) {
503                 /*
504                  * This doesn't need to be atomic, but I think it should
505                  * improve performance if it is.
506                  */
507                 sctp_con.rx_page = alloc_page(GFP_ATOMIC);
508                 if (sctp_con.rx_page == NULL)
509                         goto out_resched;
510                 CBUF_INIT(&sctp_con.cb, PAGE_CACHE_SIZE);
511         }
512
513         memset(&incmsg, 0, sizeof(incmsg));
514         memset(&msgname, 0, sizeof(msgname));
515
516         memset(incmsg, 0, sizeof(incmsg));
517         msg.msg_name = &msgname;
518         msg.msg_namelen = sizeof(msgname);
519         msg.msg_flags = 0;
520         msg.msg_control = incmsg;
521         msg.msg_controllen = sizeof(incmsg);
522
523         /* I don't see why this circular buffer stuff is necessary for SCTP
524          * which is a packet-based protocol, but the whole thing breaks under
525          * load without it! The overhead is minimal (and is in the TCP lowcomms
526          * anyway, of course) so I'll leave it in until I can figure out what's
527          * really happening.
528          */
529
530         /*
531          * iov[0] is the bit of the circular buffer between the current end
532          * point (cb.base + cb.len) and the end of the buffer.
533          */
534         iov[0].iov_len = sctp_con.cb.base - CBUF_DATA(&sctp_con.cb);
535         iov[0].iov_base = page_address(sctp_con.rx_page) +
536                           CBUF_DATA(&sctp_con.cb);
537         iov[1].iov_len = 0;
538
539         /*
540          * iov[1] is the bit of the circular buffer between the start of the
541          * buffer and the start of the currently used section (cb.base)
542          */
543         if (CBUF_DATA(&sctp_con.cb) >= sctp_con.cb.base) {
544                 iov[0].iov_len = PAGE_CACHE_SIZE - CBUF_DATA(&sctp_con.cb);
545                 iov[1].iov_len = sctp_con.cb.base;
546                 iov[1].iov_base = page_address(sctp_con.rx_page);
547                 msg.msg_iovlen = 2;
548         }
549         len = iov[0].iov_len + iov[1].iov_len;
550
551         r = ret = kernel_recvmsg(sctp_con.sock, &msg, iov, 1, len,
552                                  MSG_NOSIGNAL | MSG_DONTWAIT);
553         if (ret <= 0)
554                 goto out_close;
555
556         msg.msg_control = incmsg;
557         msg.msg_controllen = sizeof(incmsg);
558         cmsg = CMSG_FIRSTHDR(&msg);
559         sinfo = (struct sctp_sndrcvinfo *)CMSG_DATA(cmsg);
560
561         if (msg.msg_flags & MSG_NOTIFICATION) {
562                 process_sctp_notification(&msg, page_address(sctp_con.rx_page));
563                 return 0;
564         }
565
566         /* Is this a new association ? */
567         ni = nodeid2nodeinfo(le32_to_cpu(sinfo->sinfo_ppid), GFP_KERNEL);
568         if (ni) {
569                 ni->assoc_id = sinfo->sinfo_assoc_id;
570                 if (test_and_clear_bit(NI_INIT_PENDING, &ni->flags)) {
571
572                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
573                                 spin_lock_bh(&write_nodes_lock);
574                                 list_add_tail(&ni->write_list, &write_nodes);
575                                 spin_unlock_bh(&write_nodes_lock);
576                         }
577                         wake_up_process(send_task);
578                 }
579         }
580
581         /* INIT sends a message with length of 1 - ignore it */
582         if (r == 1)
583                 return 0;
584
585         CBUF_ADD(&sctp_con.cb, ret);
586         ret = dlm_process_incoming_buffer(cpu_to_le32(sinfo->sinfo_ppid),
587                                           page_address(sctp_con.rx_page),
588                                           sctp_con.cb.base, sctp_con.cb.len,
589                                           PAGE_CACHE_SIZE);
590         if (ret < 0)
591                 goto out_close;
592         CBUF_EAT(&sctp_con.cb, ret);
593
594       out:
595         ret = 0;
596         goto out_ret;
597
598       out_resched:
599         lowcomms_data_ready(sctp_con.sock->sk, 0);
600         ret = 0;
601         schedule();
602         goto out_ret;
603
604       out_close:
605         if (ret != -EAGAIN)
606                 log_print("error reading from sctp socket: %d", ret);
607       out_ret:
608         return ret;
609 }
610
611 /* Bind to an IP address. SCTP allows multiple address so it can do multi-homing */
612 static int add_bind_addr(struct sockaddr_storage *addr, int addr_len, int num)
613 {
614         mm_segment_t fs;
615         int result = 0;
616
617         fs = get_fs();
618         set_fs(get_ds());
619         if (num == 1)
620                 result = sctp_con.sock->ops->bind(sctp_con.sock,
621                                         (struct sockaddr *) addr, addr_len);
622         else
623                 result = sctp_con.sock->ops->setsockopt(sctp_con.sock, SOL_SCTP,
624                                 SCTP_SOCKOPT_BINDX_ADD, (char *)addr, addr_len);
625         set_fs(fs);
626
627         if (result < 0)
628                 log_print("Can't bind to port %d addr number %d",
629                           dlm_config.tcp_port, num);
630
631         return result;
632 }
633
634 static void init_local(void)
635 {
636         struct sockaddr_storage sas, *addr;
637         int i;
638
639         dlm_local_nodeid = dlm_our_nodeid();
640
641         for (i = 0; i < DLM_MAX_ADDR_COUNT - 1; i++) {
642                 if (dlm_our_addr(&sas, i))
643                         break;
644
645                 addr = kmalloc(sizeof(*addr), GFP_KERNEL);
646                 if (!addr)
647                         break;
648                 memcpy(addr, &sas, sizeof(*addr));
649                 dlm_local_addr[dlm_local_count++] = addr;
650         }
651 }
652
653 /* Initialise SCTP socket and bind to all interfaces */
654 static int init_sock(void)
655 {
656         mm_segment_t fs;
657         struct socket *sock = NULL;
658         struct sockaddr_storage localaddr;
659         struct sctp_event_subscribe subscribe;
660         int result = -EINVAL, num = 1, i, addr_len;
661
662         if (!dlm_local_count) {
663                 init_local();
664                 if (!dlm_local_count) {
665                         log_print("no local IP address has been set");
666                         goto out;
667                 }
668         }
669
670         result = sock_create_kern(dlm_local_addr[0]->ss_family, SOCK_SEQPACKET,
671                                   IPPROTO_SCTP, &sock);
672         if (result < 0) {
673                 log_print("Can't create comms socket, check SCTP is loaded");
674                 goto out;
675         }
676
677         /* Listen for events */
678         memset(&subscribe, 0, sizeof(subscribe));
679         subscribe.sctp_data_io_event = 1;
680         subscribe.sctp_association_event = 1;
681         subscribe.sctp_send_failure_event = 1;
682         subscribe.sctp_shutdown_event = 1;
683         subscribe.sctp_partial_delivery_event = 1;
684
685         fs = get_fs();
686         set_fs(get_ds());
687         result = sock->ops->setsockopt(sock, SOL_SCTP, SCTP_EVENTS,
688                                        (char *)&subscribe, sizeof(subscribe));
689         set_fs(fs);
690
691         if (result < 0) {
692                 log_print("Failed to set SCTP_EVENTS on socket: result=%d",
693                           result);
694                 goto create_delsock;
695         }
696
697         /* Init con struct */
698         sock->sk->sk_user_data = &sctp_con;
699         sctp_con.sock = sock;
700         sctp_con.sock->sk->sk_data_ready = lowcomms_data_ready;
701
702         /* Bind to all interfaces. */
703         for (i = 0; i < dlm_local_count; i++) {
704                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
705                 make_sockaddr(&localaddr, dlm_config.tcp_port, &addr_len);
706
707                 result = add_bind_addr(&localaddr, addr_len, num);
708                 if (result)
709                         goto create_delsock;
710                 ++num;
711         }
712
713         result = sock->ops->listen(sock, 5);
714         if (result < 0) {
715                 log_print("Can't set socket listening");
716                 goto create_delsock;
717         }
718
719         return 0;
720
721  create_delsock:
722         sock_release(sock);
723         sctp_con.sock = NULL;
724  out:
725         return result;
726 }
727
728
729 static struct writequeue_entry *new_writequeue_entry(int allocation)
730 {
731         struct writequeue_entry *entry;
732
733         entry = kmalloc(sizeof(struct writequeue_entry), allocation);
734         if (!entry)
735                 return NULL;
736
737         entry->page = alloc_page(allocation);
738         if (!entry->page) {
739                 kfree(entry);
740                 return NULL;
741         }
742
743         entry->offset = 0;
744         entry->len = 0;
745         entry->end = 0;
746         entry->users = 0;
747
748         return entry;
749 }
750
/*
 * Reserve @len bytes of outgoing buffer space for @nodeid.
 *
 * The space comes from the tail of the last entry on the node's writequeue
 * if that entry has room.  Otherwise it comes from a freshly allocated entry
 * appended to the queue.  *@ppc is set to the start of the reserved space.
 * The entry is returned as an opaque handle to pass to
 * dlm_lowcomms_commit_buffer() once the data has been written.
 *
 * @allocation is passed as the allocation flags to nodeid2nodeinfo() and
 * new_writequeue_entry().  Returns NULL if we are not accepting, if the
 * nodeinfo cannot be found or created, or if a new entry cannot be
 * allocated.
 */
void *dlm_lowcomms_get_buffer(int nodeid, int len, int allocation, char **ppc)
{
        struct writequeue_entry *e;
        int offset = 0;
        int users = 0;
        struct nodeinfo *ni;

        if (!atomic_read(&accepting))
                return NULL;

        ni = nodeid2nodeinfo(nodeid, allocation);
        if (!ni)
                return NULL;

        spin_lock(&ni->writequeue_lock);
        /* Last entry on the queue.  On an empty queue prev is the list head
         * itself, so e is not a real entry; that case is tested first. */
        e = list_entry(ni->writequeue.prev, struct writequeue_entry, list);
        if (((struct list_head *) e == &ni->writequeue) ||
            (PAGE_CACHE_SIZE - e->end < len)) {
                e = NULL;
        } else {
                offset = e->end;
                e->end += len;
                users = e->users++;
        }
        spin_unlock(&ni->writequeue_lock);

        if (e) {
              /* Shared tail of both paths (the new-entry path jumps here).
               * users is the count before our reservation; the first
               * outstanding reservation kmap()s the page. */
              got_one:
                if (users == 0)
                        kmap(e->page);
                *ppc = page_address(e->page) + offset;
                return e;
        }

        e = new_writequeue_entry(allocation);
        if (e) {
                spin_lock(&ni->writequeue_lock);
                offset = e->end;
                e->end += len;
                e->ni = ni;
                users = e->users++;
                list_add_tail(&e->list, &ni->writequeue);
                spin_unlock(&ni->writequeue_lock);
                goto got_one;
        }
        return NULL;
}
798
799 void dlm_lowcomms_commit_buffer(void *arg)
800 {
801         struct writequeue_entry *e = (struct writequeue_entry *) arg;
802         int users;
803         struct nodeinfo *ni = e->ni;
804
805         if (!atomic_read(&accepting))
806                 return;
807
808         spin_lock(&ni->writequeue_lock);
809         users = --e->users;
810         if (users)
811                 goto out;
812         e->len = e->end - e->offset;
813         kunmap(e->page);
814         spin_unlock(&ni->writequeue_lock);
815
816         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
817                 spin_lock_bh(&write_nodes_lock);
818                 list_add_tail(&ni->write_list, &write_nodes);
819                 spin_unlock_bh(&write_nodes_lock);
820                 wake_up_process(send_task);
821         }
822         return;
823
824       out:
825         spin_unlock(&ni->writequeue_lock);
826         return;
827 }
828
829 static void free_entry(struct writequeue_entry *e)
830 {
831         __free_page(e->page);
832         kfree(e);
833 }
834
835 /* Initiate an SCTP association. In theory we could just use sendmsg() on
836    the first IP address and it should work, but this allows us to set up the
837    association before sending any valuable data that we can't afford to lose.
838    It also keeps the send path clean as it can now always use the association ID */
839 static void initiate_association(int nodeid)
840 {
841         struct sockaddr_storage rem_addr;
842         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
843         struct msghdr outmessage;
844         struct cmsghdr *cmsg;
845         struct sctp_sndrcvinfo *sinfo;
846         int ret;
847         int addrlen;
848         char buf[1];
849         struct kvec iov[1];
850         struct nodeinfo *ni;
851
852         log_print("Initiating association with node %d", nodeid);
853
854         ni = nodeid2nodeinfo(nodeid, GFP_KERNEL);
855         if (!ni)
856                 return;
857
858         if (nodeid_to_addr(nodeid, (struct sockaddr *)&rem_addr)) {
859                 log_print("no address for nodeid %d", nodeid);
860                 return;
861         }
862
863         make_sockaddr(&rem_addr, dlm_config.tcp_port, &addrlen);
864
865         outmessage.msg_name = &rem_addr;
866         outmessage.msg_namelen = addrlen;
867         outmessage.msg_control = outcmsg;
868         outmessage.msg_controllen = sizeof(outcmsg);
869         outmessage.msg_flags = MSG_EOR;
870
871         iov[0].iov_base = buf;
872         iov[0].iov_len = 1;
873
874         /* Real INIT messages seem to cause trouble. Just send a 1 byte message
875            we can afford to lose */
876         cmsg = CMSG_FIRSTHDR(&outmessage);
877         cmsg->cmsg_level = IPPROTO_SCTP;
878         cmsg->cmsg_type = SCTP_SNDRCV;
879         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
880         sinfo = (struct sctp_sndrcvinfo *)CMSG_DATA(cmsg);
881         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
882         sinfo->sinfo_ppid = cpu_to_le32(dlm_local_nodeid);
883
884         outmessage.msg_controllen = cmsg->cmsg_len;
885         ret = kernel_sendmsg(sctp_con.sock, &outmessage, iov, 1, 1);
886         if (ret < 0) {
887                 log_print("send INIT to node failed: %d", ret);
888                 /* Try again later */
889                 clear_bit(NI_INIT_PENDING, &ni->flags);
890         }
891 }
892
893 /* Send a message */
894 static int send_to_sock(struct nodeinfo *ni)
895 {
896         int ret = 0;
897         struct writequeue_entry *e;
898         int len, offset;
899         struct msghdr outmsg;
900         static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
901         struct cmsghdr *cmsg;
902         struct sctp_sndrcvinfo *sinfo;
903         struct kvec iov;
904
905         /* See if we need to init an association before we start
906            sending precious messages */
907         spin_lock(&ni->lock);
908         if (!ni->assoc_id && !test_and_set_bit(NI_INIT_PENDING, &ni->flags)) {
909                 spin_unlock(&ni->lock);
910                 initiate_association(ni->nodeid);
911                 return 0;
912         }
913         spin_unlock(&ni->lock);
914
915         outmsg.msg_name = NULL; /* We use assoc_id */
916         outmsg.msg_namelen = 0;
917         outmsg.msg_control = outcmsg;
918         outmsg.msg_controllen = sizeof(outcmsg);
919         outmsg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR;
920
921         cmsg = CMSG_FIRSTHDR(&outmsg);
922         cmsg->cmsg_level = IPPROTO_SCTP;
923         cmsg->cmsg_type = SCTP_SNDRCV;
924         cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
925         sinfo = (struct sctp_sndrcvinfo *)CMSG_DATA(cmsg);
926         memset(sinfo, 0x00, sizeof(struct sctp_sndrcvinfo));
927         sinfo->sinfo_ppid = cpu_to_le32(dlm_local_nodeid);
928         sinfo->sinfo_assoc_id = ni->assoc_id;
929         outmsg.msg_controllen = cmsg->cmsg_len;
930
931         spin_lock(&ni->writequeue_lock);
932         for (;;) {
933                 if (list_empty(&ni->writequeue))
934                         break;
935                 e = list_entry(ni->writequeue.next, struct writequeue_entry,
936                                list);
937                 len = e->len;
938                 offset = e->offset;
939                 BUG_ON(len == 0 && e->users == 0);
940                 spin_unlock(&ni->writequeue_lock);
941                 kmap(e->page);
942
943                 ret = 0;
944                 if (len) {
945                         iov.iov_base = page_address(e->page)+offset;
946                         iov.iov_len = len;
947
948                         ret = kernel_sendmsg(sctp_con.sock, &outmsg, &iov, 1,
949                                              len);
950                         if (ret == -EAGAIN) {
951                                 sctp_con.eagain_flag = 1;
952                                 goto out;
953                         } else if (ret < 0)
954                                 goto send_error;
955                 } else {
956                         /* Don't starve people filling buffers */
957                         schedule();
958                 }
959
960                 spin_lock(&ni->writequeue_lock);
961                 e->offset += ret;
962                 e->len -= ret;
963
964                 if (e->len == 0 && e->users == 0) {
965                         list_del(&e->list);
966                         free_entry(e);
967                         continue;
968                 }
969         }
970         spin_unlock(&ni->writequeue_lock);
971  out:
972         return ret;
973
974  send_error:
975         log_print("Error sending to node %d %d", ni->nodeid, ret);
976         spin_lock(&ni->lock);
977         if (!test_and_set_bit(NI_INIT_PENDING, &ni->flags)) {
978                 ni->assoc_id = 0;
979                 spin_unlock(&ni->lock);
980                 initiate_association(ni->nodeid);
981         } else
982                 spin_unlock(&ni->lock);
983
984         return ret;
985 }
986
987 /* Try to send any messages that are pending */
988 static void process_output_queue(void)
989 {
990         struct list_head *list;
991         struct list_head *temp;
992
993         spin_lock_bh(&write_nodes_lock);
994         list_for_each_safe(list, temp, &write_nodes) {
995                 struct nodeinfo *ni =
996                     list_entry(list, struct nodeinfo, write_list);
997                 clear_bit(NI_WRITE_PENDING, &ni->flags);
998                 list_del(&ni->write_list);
999
1000                 spin_unlock_bh(&write_nodes_lock);
1001
1002                 send_to_sock(ni);
1003                 spin_lock_bh(&write_nodes_lock);
1004         }
1005         spin_unlock_bh(&write_nodes_lock);
1006 }
1007
1008 /* Called after we've had -EAGAIN and been woken up */
1009 static void refill_write_queue(void)
1010 {
1011         int i;
1012
1013         for (i=1; i<=max_nodeid; i++) {
1014                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1015
1016                 if (ni) {
1017                         if (!test_and_set_bit(NI_WRITE_PENDING, &ni->flags)) {
1018                                 spin_lock_bh(&write_nodes_lock);
1019                                 list_add_tail(&ni->write_list, &write_nodes);
1020                                 spin_unlock_bh(&write_nodes_lock);
1021                         }
1022                 }
1023         }
1024 }
1025
1026 static void clean_one_writequeue(struct nodeinfo *ni)
1027 {
1028         struct list_head *list;
1029         struct list_head *temp;
1030
1031         spin_lock(&ni->writequeue_lock);
1032         list_for_each_safe(list, temp, &ni->writequeue) {
1033                 struct writequeue_entry *e =
1034                         list_entry(list, struct writequeue_entry, list);
1035                 list_del(&e->list);
1036                 free_entry(e);
1037         }
1038         spin_unlock(&ni->writequeue_lock);
1039 }
1040
1041 static void clean_writequeues(void)
1042 {
1043         int i;
1044
1045         for (i=1; i<=max_nodeid; i++) {
1046                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1047                 if (ni)
1048                         clean_one_writequeue(ni);
1049         }
1050 }
1051
1052
1053 static void dealloc_nodeinfo(void)
1054 {
1055         int i;
1056
1057         for (i=1; i<=max_nodeid; i++) {
1058                 struct nodeinfo *ni = nodeid2nodeinfo(i, 0);
1059                 if (ni) {
1060                         idr_remove(&nodeinfo_idr, i);
1061                         kfree(ni);
1062                 }
1063         }
1064 }
1065
1066 int dlm_lowcomms_close(int nodeid)
1067 {
1068         struct nodeinfo *ni;
1069
1070         ni = nodeid2nodeinfo(nodeid, 0);
1071         if (!ni)
1072                 return -1;
1073
1074         spin_lock(&ni->lock);
1075         if (ni->assoc_id) {
1076                 ni->assoc_id = 0;
1077                 /* Don't send shutdown here, sctp will just queue it
1078                    till the node comes back up! */
1079         }
1080         spin_unlock(&ni->lock);
1081
1082         clean_one_writequeue(ni);
1083         clear_bit(NI_INIT_PENDING, &ni->flags);
1084         return 0;
1085 }
1086
1087 static int write_list_empty(void)
1088 {
1089         int status;
1090
1091         spin_lock_bh(&write_nodes_lock);
1092         status = list_empty(&write_nodes);
1093         spin_unlock_bh(&write_nodes_lock);
1094
1095         return status;
1096 }
1097
1098 static int dlm_recvd(void *data)
1099 {
1100         DECLARE_WAITQUEUE(wait, current);
1101
1102         while (!kthread_should_stop()) {
1103                 int count = 0;
1104
1105                 set_current_state(TASK_INTERRUPTIBLE);
1106                 add_wait_queue(&lowcomms_recv_wait, &wait);
1107                 if (!test_bit(CF_READ_PENDING, &sctp_con.flags))
1108                         schedule();
1109                 remove_wait_queue(&lowcomms_recv_wait, &wait);
1110                 set_current_state(TASK_RUNNING);
1111
1112                 if (test_and_clear_bit(CF_READ_PENDING, &sctp_con.flags)) {
1113                         int ret;
1114
1115                         do {
1116                                 ret = receive_from_sock();
1117
1118                                 /* Don't starve out everyone else */
1119                                 if (++count >= MAX_RX_MSG_COUNT) {
1120                                         schedule();
1121                                         count = 0;
1122                                 }
1123                         } while (!kthread_should_stop() && ret >=0);
1124                 }
1125                 schedule();
1126         }
1127
1128         return 0;
1129 }
1130
1131 static int dlm_sendd(void *data)
1132 {
1133         DECLARE_WAITQUEUE(wait, current);
1134
1135         add_wait_queue(sctp_con.sock->sk->sk_sleep, &wait);
1136
1137         while (!kthread_should_stop()) {
1138                 set_current_state(TASK_INTERRUPTIBLE);
1139                 if (write_list_empty())
1140                         schedule();
1141                 set_current_state(TASK_RUNNING);
1142
1143                 if (sctp_con.eagain_flag) {
1144                         sctp_con.eagain_flag = 0;
1145                         refill_write_queue();
1146                 }
1147                 process_output_queue();
1148         }
1149
1150         remove_wait_queue(sctp_con.sock->sk->sk_sleep, &wait);
1151
1152         return 0;
1153 }
1154
1155 static void daemons_stop(void)
1156 {
1157         kthread_stop(recv_task);
1158         kthread_stop(send_task);
1159 }
1160
1161 static int daemons_start(void)
1162 {
1163         struct task_struct *p;
1164         int error;
1165
1166         p = kthread_run(dlm_recvd, NULL, "dlm_recvd");
1167         error = IS_ERR(p);
1168         if (error) {
1169                 log_print("can't start dlm_recvd %d", error);
1170                 return error;
1171         }
1172         recv_task = p;
1173
1174         p = kthread_run(dlm_sendd, NULL, "dlm_sendd");
1175         error = IS_ERR(p);
1176         if (error) {
1177                 log_print("can't start dlm_sendd %d", error);
1178                 kthread_stop(recv_task);
1179                 return error;
1180         }
1181         send_task = p;
1182
1183         return 0;
1184 }
1185
1186 /*
1187  * This is quite likely to sleep...
1188  */
1189 int dlm_lowcomms_start(void)
1190 {
1191         int error;
1192
1193         error = init_sock();
1194         if (error)
1195                 goto fail_sock;
1196         error = daemons_start();
1197         if (error)
1198                 goto fail_sock;
1199         atomic_set(&accepting, 1);
1200         return 0;
1201
1202  fail_sock:
1203         close_connection();
1204         return error;
1205 }
1206
1207 /* Set all the activity flags to prevent any socket activity. */
1208
1209 void dlm_lowcomms_stop(void)
1210 {
1211         atomic_set(&accepting, 0);
1212         sctp_con.flags = 0x7;
1213         daemons_stop();
1214         clean_writequeues();
1215         close_connection();
1216         dealloc_nodeinfo();
1217         max_nodeid = 0;
1218 }
1219
1220 int dlm_lowcomms_init(void)
1221 {
1222         init_waitqueue_head(&lowcomms_recv_wait);
1223         spin_lock_init(&write_nodes_lock);
1224         INIT_LIST_HEAD(&write_nodes);
1225         init_rwsem(&nodeinfo_lock);
1226         return 0;
1227 }
1228
1229 void dlm_lowcomms_exit(void)
1230 {
1231         int i;
1232
1233         for (i = 0; i < dlm_local_count; i++)
1234                 kfree(dlm_local_addr[i]);
1235         dlm_local_count = 0;
1236         dlm_local_nodeid = 0;
1237 }
1238