- New implementation of SendMessage, ReceiveMessage, ReplyMessage functions
[wine] / server / socket.c
1 /*
2  * Server-side socket communication functions
3  *
4  * Copyright (C) 1998 Alexandre Julliard
5  */
6
7 #include <assert.h>
8 #include <errno.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <stdarg.h>
12 #include <string.h>
13 #include <sys/time.h>
14 #include <sys/types.h>
15 #include <sys/socket.h>
16 #include <sys/uio.h>
17 #include <unistd.h>
18
19 #include "config.h"
20 #include "server.h"
21
22 #include "server/object.h"
23
24 /* Some versions of glibc don't define this */
25 #ifndef SCM_RIGHTS
26 #define SCM_RIGHTS 1
27 #endif
28
29 /* client state */
30 enum state
31 {
32     RUNNING,   /* running normally */
33     SENDING,   /* sending us a request */
34     WAITING,   /* waiting for us to reply */
35     READING    /* reading our reply */
36 };
37
38 /* client structure */
39 struct client
40 {
41     enum state         state;        /* client state */
42     unsigned int       seq;          /* current sequence number */
43     struct header      head;         /* current msg header */
44     char              *data;         /* current msg data */
45     int                count;        /* bytes sent/received so far */
46     int                pass_fd;      /* fd to pass to and from the client */
47     struct thread     *self;         /* client thread (opaque pointer) */
48 };
49
50 static int initial_client_fd;               /* fd of the first client */
51
52 /* exit code passed to remove_client */
53 #define OUT_OF_MEMORY  -1
54 #define BROKEN_PIPE    -2
55 #define PROTOCOL_ERROR -3
56
57
58 /* signal a client protocol error */
59 static void protocol_error( int client_fd, const char *err, ... )
60 {
61     va_list args;
62
63     va_start( args, err );
64     fprintf( stderr, "Protocol error:%d: ", client_fd );
65     vfprintf( stderr, err, args );
66     va_end( args );
67 }
68
69 /* send a message to a client that is ready to receive something */
70 static void do_write( struct client *client, int client_fd )
71 {
72     struct iovec vec[2];
73 #ifndef HAVE_MSGHDR_ACCRIGHTS
74     struct cmsg_fd cmsg  = { sizeof(cmsg), SOL_SOCKET, SCM_RIGHTS,
75                              client->pass_fd };
76 #endif
77     struct msghdr msghdr = { NULL, 0, vec, 2, };
78     int ret;
79
80     /* make sure we have something to send */
81     assert( client->count < client->head.len );
82     /* make sure the client is listening */
83     assert( client->state == READING );
84
85     if (client->count < sizeof(client->head))
86     {
87         vec[0].iov_base = (char *)&client->head + client->count;
88         vec[0].iov_len  = sizeof(client->head) - client->count;
89         vec[1].iov_base = client->data;
90         vec[1].iov_len  = client->head.len - sizeof(client->head);
91     }
92     else
93     {
94         vec[0].iov_base = client->data + client->count - sizeof(client->head);
95         vec[0].iov_len  = client->head.len - client->count;
96         msghdr.msg_iovlen = 1;
97     }
98     if (client->pass_fd != -1)  /* we have an fd to send */
99     {
100 #ifdef HAVE_MSGHDR_ACCRIGHTS
101         msghdr.msg_accrights = (void *)&client->pass_fd;
102         msghdr.msg_accrightslen = sizeof(client->pass_fd);
103 #else
104         msghdr.msg_control = &cmsg;
105         msghdr.msg_controllen = sizeof(cmsg);
106 #endif
107     }
108     ret = sendmsg( client_fd, &msghdr, 0 );
109     if (ret == -1)
110     {
111         if (errno != EPIPE) perror("sendmsg");
112         remove_client( client_fd, BROKEN_PIPE );
113         return;
114     }
115     if (client->pass_fd != -1)  /* We sent the fd, now we can close it */
116     {
117         close( client->pass_fd );
118         client->pass_fd = -1;
119     }
120     if ((client->count += ret) < client->head.len) return;
121
122     /* we have finished with this message */
123     if (client->data) free( client->data );
124     client->data  = NULL;
125     client->count = 0;
126     client->state = RUNNING;
127     client->seq++;
128     set_select_events( client_fd, READ_EVENT );
129 }
130
131
132 /* read a message from a client that has something to say */
133 static void do_read( struct client *client, int client_fd )
134 {
135     struct iovec vec;
136     int pass_fd = -1;
137 #ifdef HAVE_MSGHDR_ACCRIGHTS
138     struct msghdr msghdr = { NULL, 0, &vec, 1, (void*)&pass_fd, sizeof(int) };
139 #else
140     struct cmsg_fd cmsg  = { sizeof(cmsg), SOL_SOCKET, SCM_RIGHTS, -1 };
141     struct msghdr msghdr = { NULL, 0, &vec, 1, &cmsg, sizeof(cmsg), 0 };
142 #endif
143     int ret;
144
145     if (client->count < sizeof(client->head))
146     {
147         vec.iov_base = (char *)&client->head + client->count;
148         vec.iov_len  = sizeof(client->head) - client->count;
149     }
150     else
151     {
152         if (!client->data &&
153             !(client->data = malloc(client->head.len-sizeof(client->head))))
154         {
155             remove_client( client_fd, OUT_OF_MEMORY );
156             return;
157         }
158         vec.iov_base = client->data + client->count - sizeof(client->head);
159         vec.iov_len  = client->head.len - client->count;
160     }
161
162     ret = recvmsg( client_fd, &msghdr, 0 );
163     if (ret == -1)
164     {
165         perror("recvmsg");
166         remove_client( client_fd, BROKEN_PIPE );
167         return;
168     }
169 #ifndef HAVE_MSGHDR_ACCRIGHTS
170     pass_fd = cmsg.fd;
171 #endif
172     if (pass_fd != -1)
173     {
174         /* can only receive one fd per message */
175         if (client->pass_fd != -1) close( client->pass_fd );
176         client->pass_fd = pass_fd;
177     }
178     else if (!ret)  /* closed pipe */
179     {
180         remove_client( client_fd, BROKEN_PIPE );
181         return;
182     }
183
184     if (client->state == RUNNING) client->state = SENDING;
185     assert( client->state == SENDING );
186
187     client->count += ret;
188
189     /* received the complete header yet? */
190     if (client->count < sizeof(client->head)) return;
191
192     /* sanity checks */
193     if (client->head.seq != client->seq)
194     {
195         protocol_error( client_fd, "bad sequence %08x instead of %08x\n",
196                         client->head.seq, client->seq );
197         remove_client( client_fd, PROTOCOL_ERROR );
198         return;
199     }
200     if ((client->head.len < sizeof(client->head)) ||
201         (client->head.len > MAX_MSG_LENGTH + sizeof(client->head)))
202     {
203         protocol_error( client_fd, "bad header length %08x\n",
204                         client->head.len );
205         remove_client( client_fd, PROTOCOL_ERROR );
206         return;
207     }
208
209     /* received the whole message? */
210     if (client->count == client->head.len)
211     {
212         /* done reading the data, call the callback function */
213
214         int len = client->head.len - sizeof(client->head);
215         char *data = client->data;
216         int passed_fd = client->pass_fd;
217         enum request type = client->head.type;
218
219         /* clear the info now, as the client may be deleted by the callback */
220         client->head.len  = 0;
221         client->head.type = 0;
222         client->count     = 0;
223         client->data      = NULL;
224         client->pass_fd   = -1;
225         client->state     = WAITING;
226         client->seq++;
227
228         call_req_handler( client->self, type, data, len, passed_fd );
229         if (passed_fd != -1) close( passed_fd );
230         if (data) free( data );
231     }
232 }
233
234 /* handle a client timeout */
235 static void client_timeout( int client_fd, void *private )
236 {
237     struct client *client = (struct client *)private;
238     set_select_timeout( client_fd, 0 );  /* Remove the timeout */
239     call_timeout_handler( client->self );
240 }
241
242 /* handle a client event */
243 static void client_event( int client_fd, int event, void *private )
244 {
245     struct client *client = (struct client *)private;
246     if (event & WRITE_EVENT)
247         do_write( client, client_fd );
248     if (event & READ_EVENT)
249         do_read( client, client_fd );
250 }
251
252 static const struct select_ops client_ops =
253 {
254     client_event,
255     client_timeout
256 };
257
258 /*******************************************************************/
259 /* server-side exported functions                                  */
260
261 /* server initialization */
262 void server_init( int fd )
263 {
264     /* special magic to create the initial thread */
265     initial_client_fd = fd;
266     add_client( initial_client_fd, NULL );
267 }
268
269
270 /* add a client */
271 int add_client( int client_fd, struct thread *self )
272 {
273     struct client *client = malloc( sizeof(*client) );
274     if (!client) return -1;
275
276     client->state                = RUNNING;
277     client->seq                  = 0;
278     client->head.len             = 0;
279     client->head.type            = 0;
280     client->count                = 0;
281     client->data                 = NULL;
282     client->self                 = self;
283     client->pass_fd              = -1;
284
285     if (add_select_user( client_fd, READ_EVENT, &client_ops, client ) == -1)
286     {
287         free( client );
288         return -1;
289     }
290     return client_fd;
291 }
292
293 /* remove a client */
294 void remove_client( int client_fd, int exit_code )
295 {
296     struct client *client = (struct client *)get_select_private_data( &client_ops, client_fd );
297     assert( client );
298
299     call_kill_handler( client->self, exit_code );
300
301     remove_select_user( client_fd );
302     if (initial_client_fd == client_fd) initial_client_fd = -1;
303     close( client_fd );
304
305     /* Purge messages */
306     if (client->data) free( client->data );
307     if (client->pass_fd != -1) close( client->pass_fd );
308     free( client );
309 }
310
311 /* return the fd of the initial client */
312 int get_initial_client_fd(void)
313 {
314     assert( initial_client_fd != -1 );
315     return initial_client_fd;
316 }
317
318 /* send a reply to a client */
319 int send_reply_v( int client_fd, int type, int pass_fd,
320                   struct iovec *vec, int veclen )
321 {
322     int i;
323     unsigned int len;
324     char *p;
325     struct client *client = (struct client *)get_select_private_data( &client_ops, client_fd );
326
327     assert( client );
328     assert( client->state == WAITING );
329     assert( !client->data );
330
331     if (debug_level) trace_reply( client->self, type, pass_fd, vec, veclen );
332
333     for (i = len = 0; i < veclen; i++) len += vec[i].iov_len;
334     assert( len < MAX_MSG_LENGTH );
335
336     if (len && !(client->data = malloc( len ))) return -1;
337     client->count     = 0;
338     client->head.len  = len + sizeof(client->head);
339     client->head.type = type;
340     client->head.seq  = client->seq;
341     client->pass_fd   = pass_fd;
342
343     for (i = 0, p = client->data; i < veclen; i++)
344     {
345         memcpy( p, vec[i].iov_base, vec[i].iov_len );
346         p += vec[i].iov_len;
347     }
348
349     client->state = READING;
350     set_select_events( client_fd, WRITE_EVENT );
351     return 0;
352 }