Release 980927
[wine] / server / socket.c
1 /*
2  * Server-side socket communication functions
3  *
4  * Copyright (C) 1998 Alexandre Julliard
5  */
6
7 #include <assert.h>
8 #include <errno.h>
9 #include <fcntl.h>
10 #include <signal.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <stdarg.h>
14 #include <string.h>
15 #include <sys/time.h>
16 #include <sys/types.h>
17 #include <sys/socket.h>
18 #include <sys/uio.h>
19 #include <unistd.h>
20
21 #include "config.h"
22 #include "server.h"
23
24 #include "server/object.h"
25
26 /* Some versions of glibc don't define this */
27 #ifndef SCM_RIGHTS
28 #define SCM_RIGHTS 1
29 #endif
30
31 /* client state */
32 enum state
33 {
34     RUNNING,   /* running normally */
35     SENDING,   /* sending us a request */
36     WAITING,   /* waiting for us to reply */
37     READING    /* reading our reply */
38 };
39
40 /* client timeout */
41 struct timeout
42 {
43     struct timeval  when;    /* timeout expiry (absolute time) */
44     struct timeout *next;    /* next in sorted list */
45     struct timeout *prev;    /* prev in sorted list */
46     int             client;  /* client id */
47 };
48
49 /* client structure */
50 struct client
51 {
52     enum state         state;        /* client state */
53     unsigned int       seq;          /* current sequence number */
54     struct header      head;         /* current msg header */
55     char              *data;         /* current msg data */
56     int                count;        /* bytes sent/received so far */
57     int                pass_fd;      /* fd to pass to and from the client */
58     struct thread     *self;         /* client thread (opaque pointer) */
59     struct timeout     timeout;      /* client timeout */
60 };
61
62
63 static struct client *clients[FD_SETSIZE];  /* clients array */
64 static fd_set read_set, write_set;          /* current select sets */
65 static int nb_clients;                      /* current number of clients */
66 static int max_fd;                          /* max fd in use */
67 static int initial_client_fd;               /* fd of the first client */
68 static struct timeout *timeout_head;        /* sorted timeouts list head */
69 static struct timeout *timeout_tail;        /* sorted timeouts list tail */
70
71 /* exit code passed to remove_client */
72 #define OUT_OF_MEMORY  -1
73 #define BROKEN_PIPE    -2
74 #define PROTOCOL_ERROR -3
75
76
77 /* signal a client protocol error */
78 static void protocol_error( int client_fd, const char *err, ... )
79 {
80     va_list args;
81
82     va_start( args, err );
83     fprintf( stderr, "Protocol error:%d: ", client_fd );
84     vfprintf( stderr, err, args );
85     va_end( args );
86 }
87
88
89 /* send a message to a client that is ready to receive something */
90 static void do_write( int client_fd )
91 {
92     struct client *client = clients[client_fd];
93     struct iovec vec[2];
94 #ifndef HAVE_MSGHDR_ACCRIGHTS
95     struct cmsg_fd cmsg  = { sizeof(cmsg), SOL_SOCKET, SCM_RIGHTS,
96                              client->pass_fd };
97 #endif
98     struct msghdr msghdr = { NULL, 0, vec, 2, };
99     int ret;
100
101     /* make sure we have something to send */
102     assert( client->count < client->head.len );
103     /* make sure the client is listening */
104     assert( client->state == READING );
105
106     if (client->count < sizeof(client->head))
107     {
108         vec[0].iov_base = (char *)&client->head + client->count;
109         vec[0].iov_len  = sizeof(client->head) - client->count;
110         vec[1].iov_base = client->data;
111         vec[1].iov_len  = client->head.len - sizeof(client->head);
112     }
113     else
114     {
115         vec[0].iov_base = client->data + client->count - sizeof(client->head);
116         vec[0].iov_len  = client->head.len - client->count;
117         msghdr.msg_iovlen = 1;
118     }
119     if (client->pass_fd != -1)  /* we have an fd to send */
120     {
121 #ifdef HAVE_MSGHDR_ACCRIGHTS
122         msghdr.msg_accrights = (void *)&client->pass_fd;
123         msghdr.msg_accrightslen = sizeof(client->pass_fd);
124 #else
125         msghdr.msg_control = &cmsg;
126         msghdr.msg_controllen = sizeof(cmsg);
127 #endif
128     }
129     ret = sendmsg( client_fd, &msghdr, 0 );
130     if (ret == -1)
131     {
132         if (errno != EPIPE) perror("sendmsg");
133         remove_client( client_fd, BROKEN_PIPE );
134         return;
135     }
136     if (client->pass_fd != -1)  /* We sent the fd, now we can close it */
137     {
138         close( client->pass_fd );
139         client->pass_fd = -1;
140     }
141     if ((client->count += ret) < client->head.len) return;
142
143     /* we have finished with this message */
144     if (client->data) free( client->data );
145     client->data  = NULL;
146     client->count = 0;
147     client->state = RUNNING;
148     client->seq++;
149     FD_CLR( client_fd, &write_set );
150     FD_SET( client_fd, &read_set );
151 }
152
153
154 /* read a message from a client that has something to say */
155 static void do_read( int client_fd )
156 {
157     struct client *client = clients[client_fd];
158     struct iovec vec;
159     int pass_fd = -1;
160 #ifdef HAVE_MSGHDR_ACCRIGHTS
161     struct msghdr msghdr = { NULL, 0, &vec, 1, (void*)&pass_fd, sizeof(int) };
162 #else
163     struct cmsg_fd cmsg  = { sizeof(cmsg), SOL_SOCKET, SCM_RIGHTS, -1 };
164     struct msghdr msghdr = { NULL, 0, &vec, 1, &cmsg, sizeof(cmsg), 0 };
165 #endif
166     int ret;
167
168     if (client->count < sizeof(client->head))
169     {
170         vec.iov_base = (char *)&client->head + client->count;
171         vec.iov_len  = sizeof(client->head) - client->count;
172     }
173     else
174     {
175         if (!client->data &&
176             !(client->data = malloc(client->head.len-sizeof(client->head))))
177         {
178             remove_client( client_fd, OUT_OF_MEMORY );
179             return;
180         }
181         vec.iov_base = client->data + client->count - sizeof(client->head);
182         vec.iov_len  = client->head.len - client->count;
183     }
184
185     ret = recvmsg( client_fd, &msghdr, 0 );
186     if (ret == -1)
187     {
188         perror("recvmsg");
189         remove_client( client_fd, BROKEN_PIPE );
190         return;
191     }
192 #ifndef HAVE_MSGHDR_ACCRIGHTS
193     pass_fd = cmsg.fd;
194 #endif
195     if (pass_fd != -1)
196     {
197         /* can only receive one fd per message */
198         if (client->pass_fd != -1) close( client->pass_fd );
199         client->pass_fd = pass_fd;
200     }
201     else if (!ret)  /* closed pipe */
202     {
203         remove_client( client_fd, BROKEN_PIPE );
204         return;
205     }
206
207     if (client->state == RUNNING) client->state = SENDING;
208     assert( client->state == SENDING );
209
210     client->count += ret;
211
212     /* received the complete header yet? */
213     if (client->count < sizeof(client->head)) return;
214
215     /* sanity checks */
216     if (client->head.seq != client->seq)
217     {
218         protocol_error( client_fd, "bad sequence %08x instead of %08x\n",
219                         client->head.seq, client->seq );
220         remove_client( client_fd, PROTOCOL_ERROR );
221         return;
222     }
223     if ((client->head.len < sizeof(client->head)) ||
224         (client->head.len > MAX_MSG_LENGTH + sizeof(client->head)))
225     {
226         protocol_error( client_fd, "bad header length %08x\n",
227                         client->head.len );
228         remove_client( client_fd, PROTOCOL_ERROR );
229         return;
230     }
231
232     /* received the whole message? */
233     if (client->count == client->head.len)
234     {
235         /* done reading the data, call the callback function */
236
237         int len = client->head.len - sizeof(client->head);
238         char *data = client->data;
239         int passed_fd = client->pass_fd;
240         enum request type = client->head.type;
241
242         /* clear the info now, as the client may be deleted by the callback */
243         client->head.len  = 0;
244         client->head.type = 0;
245         client->count     = 0;
246         client->data      = NULL;
247         client->pass_fd   = -1;
248         client->state     = WAITING;
249         client->seq++;
250
251         call_req_handler( client->self, type, data, len, passed_fd );
252         if (passed_fd != -1) close( passed_fd );
253         if (data) free( data );
254     }
255 }
256
257
258 /* handle a client timeout */
259 static void do_timeout( int client_fd )
260 {
261     struct client *client = clients[client_fd];
262     set_timeout( client_fd, 0 );  /* Remove the timeout */
263     call_timeout_handler( client->self );
264 }
265
266
267 /* server main loop */
268 void server_main_loop( int fd )
269 {
270     int i, ret;
271
272     setsid();
273     signal( SIGPIPE, SIG_IGN );
274
275     /* special magic to create the initial thread */
276     initial_client_fd = fd;
277     add_client( initial_client_fd, NULL );
278
279     while (nb_clients)
280     {
281         fd_set read = read_set, write = write_set;
282 #if 0
283         printf( "select: " );
284         for (i = 0; i <= max_fd; i++) printf( "%c", FD_ISSET( i, &read_set ) ? 'r' :
285                                                     (FD_ISSET( i, &write_set ) ? 'w' : '-') );
286         printf( "\n" );
287 #endif
288         if (timeout_head)
289         {
290             struct timeval tv, now;
291             gettimeofday( &now, NULL );
292             if ((timeout_head->when.tv_sec < now.tv_sec) ||
293                 ((timeout_head->when.tv_sec == now.tv_sec) &&
294                  (timeout_head->when.tv_usec < now.tv_usec)))
295             {
296                 do_timeout( timeout_head->client );
297                 continue;
298             }
299             tv.tv_sec = timeout_head->when.tv_sec - now.tv_sec;
300             if ((tv.tv_usec = timeout_head->when.tv_usec - now.tv_usec) < 0)
301             {
302                 tv.tv_usec += 1000000;
303                 tv.tv_sec--;
304             }
305             ret = select( max_fd + 1, &read, &write, NULL, &tv );
306         }
307         else  /* no timeout */
308         {
309             ret = select( max_fd + 1, &read, &write, NULL, NULL );
310         }
311
312         if (!ret) continue;
313         if (ret == -1) perror("select");
314
315         for (i = 0; i <= max_fd; i++)
316         {
317             if (FD_ISSET( i, &write ))
318             {
319                 if (clients[i]) do_write( i );
320             }
321             else if (FD_ISSET( i, &read ))
322             {
323                 if (clients[i]) do_read( i );
324             }
325         }
326     }
327 }
328
329
330 /*******************************************************************/
331 /* server-side exported functions                                  */
332
333 /* add a client */
334 int add_client( int client_fd, struct thread *self )
335 {
336     int flags;
337     struct client *client = malloc( sizeof(*client) );
338     if (!client) return -1;
339     assert( !clients[client_fd] );
340
341     client->state                = RUNNING;
342     client->seq                  = 0;
343     client->head.len             = 0;
344     client->head.type            = 0;
345     client->count                = 0;
346     client->data                 = NULL;
347     client->self                 = self;
348     client->pass_fd              = -1;
349     client->timeout.when.tv_sec  = 0;
350     client->timeout.when.tv_usec = 0;
351     client->timeout.client       = client_fd;
352
353     flags = fcntl( client_fd, F_GETFL, 0 );
354     fcntl( client_fd, F_SETFL, flags | O_NONBLOCK );
355
356     clients[client_fd] = client;
357     FD_SET( client_fd, &read_set );
358     if (client_fd > max_fd) max_fd = client_fd;
359     nb_clients++;
360     return client_fd;
361 }
362
363 /* remove a client */
364 void remove_client( int client_fd, int exit_code )
365 {
366     struct client *client = clients[client_fd];
367     assert( client );
368
369     call_kill_handler( client->self, exit_code );
370
371     set_timeout( client_fd, 0 );
372     clients[client_fd] = NULL;
373     FD_CLR( client_fd, &read_set );
374     FD_CLR( client_fd, &write_set );
375     if (max_fd == client_fd) while (max_fd && !clients[max_fd]) max_fd--;
376     if (initial_client_fd == client_fd) initial_client_fd = -1;
377     close( client_fd );
378     nb_clients--;
379
380     /* Purge messages */
381     if (client->data) free( client->data );
382     if (client->pass_fd != -1) close( client->pass_fd );
383     free( client );
384 }
385
386 /* return the fd of the initial client */
387 int get_initial_client_fd(void)
388 {
389     assert( initial_client_fd != -1 );
390     return initial_client_fd;
391 }
392
393 /* set a client timeout */
394 void set_timeout( int client_fd, struct timeval *when )
395 {
396     struct timeout *tm, *pos;
397     struct client *client = clients[client_fd];
398     assert( client );
399
400     tm = &client->timeout;
401     if (tm->when.tv_sec || tm->when.tv_usec)
402     {
403         /* there is already a timeout */
404         if (tm->next) tm->next->prev = tm->prev;
405         else timeout_tail = tm->prev;
406         if (tm->prev) tm->prev->next = tm->next;
407         else timeout_head = tm->next;
408         tm->when.tv_sec = tm->when.tv_usec = 0;
409     }
410     if (!when) return;  /* no timeout */
411     tm->when = *when;
412
413     /* Now insert it in the linked list */
414
415     for (pos = timeout_head; pos; pos = pos->next)
416     {
417         if (pos->when.tv_sec > tm->when.tv_sec) break;
418         if ((pos->when.tv_sec == tm->when.tv_sec) &&
419             (pos->when.tv_usec > tm->when.tv_usec)) break;
420     }
421
422     if (pos)  /* insert it before 'pos' */
423     {
424         if ((tm->prev = pos->prev)) tm->prev->next = tm;
425         else timeout_head = tm;
426         tm->next = pos;
427         pos->prev = tm;
428     }
429     else  /* insert it at the tail */
430     {
431         tm->next = NULL;
432         if (timeout_tail) timeout_tail->next = tm;
433         else timeout_head = tm;
434         tm->prev = timeout_tail;
435         timeout_tail = tm;
436     }
437 }
438
439
440 /* send a reply to a client */
441 int send_reply_v( int client_fd, int type, int pass_fd,
442                   struct iovec *vec, int veclen )
443 {
444     int i;
445     unsigned int len;
446     char *p;
447     struct client *client = clients[client_fd];
448
449     assert( client );
450     assert( client->state == WAITING );
451     assert( !client->data );
452
453     if (debug_level) trace_reply( client->self, type, pass_fd, vec, veclen );
454
455     for (i = len = 0; i < veclen; i++) len += vec[i].iov_len;
456     assert( len < MAX_MSG_LENGTH );
457
458     if (len && !(client->data = malloc( len ))) return -1;
459     client->count     = 0;
460     client->head.len  = len + sizeof(client->head);
461     client->head.type = type;
462     client->head.seq  = client->seq;
463     client->pass_fd   = pass_fd;
464
465     for (i = 0, p = client->data; i < veclen; i++)
466     {
467         memcpy( p, vec[i].iov_base, vec[i].iov_len );
468         p += vec[i].iov_len;
469     }
470
471     client->state = READING;
472     FD_CLR( client_fd, &read_set );
473     FD_SET( client_fd, &write_set );
474     return 0;
475 }