Merge branch 'linus' into core/futexes
[linux-2.6] / net / phonet / socket.c
1 /*
2  * File: socket.c
3  *
4  * Phonet sockets
5  *
6  * Copyright (C) 2008 Nokia Corporation.
7  *
8  * Contact: Remi Denis-Courmont <remi.denis-courmont@nokia.com>
9  * Original author: Sakari Ailus <sakari.ailus@nokia.com>
10  *
11  * This program is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU General Public License
13  * version 2 as published by the Free Software Foundation.
14  *
15  * This program is distributed in the hope that it will be useful, but
16  * WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * General Public License for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
23  * 02110-1301 USA
24  */
25
26 #include <linux/kernel.h>
27 #include <linux/net.h>
28 #include <linux/poll.h>
29 #include <net/sock.h>
30 #include <net/tcp_states.h>
31
32 #include <linux/phonet.h>
33 #include <net/phonet/phonet.h>
34 #include <net/phonet/pep.h>
35 #include <net/phonet/pn_dev.h>
36
37 static int pn_socket_release(struct socket *sock)
38 {
39         struct sock *sk = sock->sk;
40
41         if (sk) {
42                 sock->sk = NULL;
43                 sk->sk_prot->close(sk, 0);
44         }
45         return 0;
46 }
47
48 static struct  {
49         struct hlist_head hlist;
50         spinlock_t lock;
51 } pnsocks = {
52         .hlist = HLIST_HEAD_INIT,
53         .lock = __SPIN_LOCK_UNLOCKED(pnsocks.lock),
54 };
55
56 /*
57  * Find address based on socket address, match only certain fields.
58  * Also grab sock if it was found. Remember to sock_put it later.
59  */
60 struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *spn)
61 {
62         struct hlist_node *node;
63         struct sock *sknode;
64         struct sock *rval = NULL;
65         u16 obj = pn_sockaddr_get_object(spn);
66         u8 res = spn->spn_resource;
67
68         spin_lock_bh(&pnsocks.lock);
69
70         sk_for_each(sknode, node, &pnsocks.hlist) {
71                 struct pn_sock *pn = pn_sk(sknode);
72                 BUG_ON(!pn->sobject); /* unbound socket */
73
74                 if (!net_eq(sock_net(sknode), net))
75                         continue;
76                 if (pn_port(obj)) {
77                         /* Look up socket by port */
78                         if (pn_port(pn->sobject) != pn_port(obj))
79                                 continue;
80                 } else {
81                         /* If port is zero, look up by resource */
82                         if (pn->resource != res)
83                                 continue;
84                 }
85                 if (pn_addr(pn->sobject)
86                  && pn_addr(pn->sobject) != pn_addr(obj))
87                         continue;
88
89                 rval = sknode;
90                 sock_hold(sknode);
91                 break;
92         }
93
94         spin_unlock_bh(&pnsocks.lock);
95
96         return rval;
97
98 }
99
100 void pn_sock_hash(struct sock *sk)
101 {
102         spin_lock_bh(&pnsocks.lock);
103         sk_add_node(sk, &pnsocks.hlist);
104         spin_unlock_bh(&pnsocks.lock);
105 }
106 EXPORT_SYMBOL(pn_sock_hash);
107
108 void pn_sock_unhash(struct sock *sk)
109 {
110         spin_lock_bh(&pnsocks.lock);
111         sk_del_node_init(sk);
112         spin_unlock_bh(&pnsocks.lock);
113 }
114 EXPORT_SYMBOL(pn_sock_unhash);
115
116 static int pn_socket_bind(struct socket *sock, struct sockaddr *addr, int len)
117 {
118         struct sock *sk = sock->sk;
119         struct pn_sock *pn = pn_sk(sk);
120         struct sockaddr_pn *spn = (struct sockaddr_pn *)addr;
121         int err;
122         u16 handle;
123         u8 saddr;
124
125         if (sk->sk_prot->bind)
126                 return sk->sk_prot->bind(sk, addr, len);
127
128         if (len < sizeof(struct sockaddr_pn))
129                 return -EINVAL;
130         if (spn->spn_family != AF_PHONET)
131                 return -EAFNOSUPPORT;
132
133         handle = pn_sockaddr_get_object((struct sockaddr_pn *)addr);
134         saddr = pn_addr(handle);
135         if (saddr && phonet_address_lookup(sock_net(sk), saddr))
136                 return -EADDRNOTAVAIL;
137
138         lock_sock(sk);
139         if (sk->sk_state != TCP_CLOSE || pn_port(pn->sobject)) {
140                 err = -EINVAL; /* attempt to rebind */
141                 goto out;
142         }
143         err = sk->sk_prot->get_port(sk, pn_port(handle));
144         if (err)
145                 goto out;
146
147         /* get_port() sets the port, bind() sets the address if applicable */
148         pn->sobject = pn_object(saddr, pn_port(pn->sobject));
149         pn->resource = spn->spn_resource;
150
151         /* Enable RX on the socket */
152         sk->sk_prot->hash(sk);
153 out:
154         release_sock(sk);
155         return err;
156 }
157
158 static int pn_socket_autobind(struct socket *sock)
159 {
160         struct sockaddr_pn sa;
161         int err;
162
163         memset(&sa, 0, sizeof(sa));
164         sa.spn_family = AF_PHONET;
165         err = pn_socket_bind(sock, (struct sockaddr *)&sa,
166                                 sizeof(struct sockaddr_pn));
167         if (err != -EINVAL)
168                 return err;
169         BUG_ON(!pn_port(pn_sk(sock->sk)->sobject));
170         return 0; /* socket was already bound */
171 }
172
173 static int pn_socket_accept(struct socket *sock, struct socket *newsock,
174                                 int flags)
175 {
176         struct sock *sk = sock->sk;
177         struct sock *newsk;
178         int err;
179
180         newsk = sk->sk_prot->accept(sk, flags, &err);
181         if (!newsk)
182                 return err;
183
184         lock_sock(newsk);
185         sock_graft(newsk, newsock);
186         newsock->state = SS_CONNECTED;
187         release_sock(newsk);
188         return 0;
189 }
190
191 static int pn_socket_getname(struct socket *sock, struct sockaddr *addr,
192                                 int *sockaddr_len, int peer)
193 {
194         struct sock *sk = sock->sk;
195         struct pn_sock *pn = pn_sk(sk);
196
197         memset(addr, 0, sizeof(struct sockaddr_pn));
198         addr->sa_family = AF_PHONET;
199         if (!peer) /* Race with bind() here is userland's problem. */
200                 pn_sockaddr_set_object((struct sockaddr_pn *)addr,
201                                         pn->sobject);
202
203         *sockaddr_len = sizeof(struct sockaddr_pn);
204         return 0;
205 }
206
207 static unsigned int pn_socket_poll(struct file *file, struct socket *sock,
208                                         poll_table *wait)
209 {
210         struct sock *sk = sock->sk;
211         struct pep_sock *pn = pep_sk(sk);
212         unsigned int mask = 0;
213
214         poll_wait(file, &sock->wait, wait);
215
216         switch (sk->sk_state) {
217         case TCP_LISTEN:
218                 return hlist_empty(&pn->ackq) ? 0 : POLLIN;
219         case TCP_CLOSE:
220                 return POLLERR;
221         }
222
223         if (!skb_queue_empty(&sk->sk_receive_queue))
224                 mask |= POLLIN | POLLRDNORM;
225         if (!skb_queue_empty(&pn->ctrlreq_queue))
226                 mask |= POLLPRI;
227         if (!mask && sk->sk_state == TCP_CLOSE_WAIT)
228                 return POLLHUP;
229
230         if (sk->sk_state == TCP_ESTABLISHED && atomic_read(&pn->tx_credits))
231                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
232
233         return mask;
234 }
235
236 static int pn_socket_ioctl(struct socket *sock, unsigned int cmd,
237                                 unsigned long arg)
238 {
239         struct sock *sk = sock->sk;
240         struct pn_sock *pn = pn_sk(sk);
241
242         if (cmd == SIOCPNGETOBJECT) {
243                 struct net_device *dev;
244                 u16 handle;
245                 u8 saddr;
246
247                 if (get_user(handle, (__u16 __user *)arg))
248                         return -EFAULT;
249
250                 lock_sock(sk);
251                 if (sk->sk_bound_dev_if)
252                         dev = dev_get_by_index(sock_net(sk),
253                                                 sk->sk_bound_dev_if);
254                 else
255                         dev = phonet_device_get(sock_net(sk));
256                 if (dev && (dev->flags & IFF_UP))
257                         saddr = phonet_address_get(dev, pn_addr(handle));
258                 else
259                         saddr = PN_NO_ADDR;
260                 release_sock(sk);
261
262                 if (dev)
263                         dev_put(dev);
264                 if (saddr == PN_NO_ADDR)
265                         return -EHOSTUNREACH;
266
267                 handle = pn_object(saddr, pn_port(pn->sobject));
268                 return put_user(handle, (__u16 __user *)arg);
269         }
270
271         return sk->sk_prot->ioctl(sk, cmd, arg);
272 }
273
274 static int pn_socket_listen(struct socket *sock, int backlog)
275 {
276         struct sock *sk = sock->sk;
277         int err = 0;
278
279         if (sock->state != SS_UNCONNECTED)
280                 return -EINVAL;
281         if (pn_socket_autobind(sock))
282                 return -ENOBUFS;
283
284         lock_sock(sk);
285         if (sk->sk_state != TCP_CLOSE) {
286                 err = -EINVAL;
287                 goto out;
288         }
289
290         sk->sk_state = TCP_LISTEN;
291         sk->sk_ack_backlog = 0;
292         sk->sk_max_ack_backlog = backlog;
293 out:
294         release_sock(sk);
295         return err;
296 }
297
298 static int pn_socket_sendmsg(struct kiocb *iocb, struct socket *sock,
299                                 struct msghdr *m, size_t total_len)
300 {
301         struct sock *sk = sock->sk;
302
303         if (pn_socket_autobind(sock))
304                 return -EAGAIN;
305
306         return sk->sk_prot->sendmsg(iocb, sk, m, total_len);
307 }
308
309 const struct proto_ops phonet_dgram_ops = {
310         .family         = AF_PHONET,
311         .owner          = THIS_MODULE,
312         .release        = pn_socket_release,
313         .bind           = pn_socket_bind,
314         .connect        = sock_no_connect,
315         .socketpair     = sock_no_socketpair,
316         .accept         = sock_no_accept,
317         .getname        = pn_socket_getname,
318         .poll           = datagram_poll,
319         .ioctl          = pn_socket_ioctl,
320         .listen         = sock_no_listen,
321         .shutdown       = sock_no_shutdown,
322         .setsockopt     = sock_no_setsockopt,
323         .getsockopt     = sock_no_getsockopt,
324 #ifdef CONFIG_COMPAT
325         .compat_setsockopt = sock_no_setsockopt,
326         .compat_getsockopt = sock_no_getsockopt,
327 #endif
328         .sendmsg        = pn_socket_sendmsg,
329         .recvmsg        = sock_common_recvmsg,
330         .mmap           = sock_no_mmap,
331         .sendpage       = sock_no_sendpage,
332 };
333
334 const struct proto_ops phonet_stream_ops = {
335         .family         = AF_PHONET,
336         .owner          = THIS_MODULE,
337         .release        = pn_socket_release,
338         .bind           = pn_socket_bind,
339         .connect        = sock_no_connect,
340         .socketpair     = sock_no_socketpair,
341         .accept         = pn_socket_accept,
342         .getname        = pn_socket_getname,
343         .poll           = pn_socket_poll,
344         .ioctl          = pn_socket_ioctl,
345         .listen         = pn_socket_listen,
346         .shutdown       = sock_no_shutdown,
347         .setsockopt     = sock_common_setsockopt,
348         .getsockopt     = sock_common_getsockopt,
349 #ifdef CONFIG_COMPAT
350         .compat_setsockopt = compat_sock_common_setsockopt,
351         .compat_getsockopt = compat_sock_common_getsockopt,
352 #endif
353         .sendmsg        = pn_socket_sendmsg,
354         .recvmsg        = sock_common_recvmsg,
355         .mmap           = sock_no_mmap,
356         .sendpage       = sock_no_sendpage,
357 };
358 EXPORT_SYMBOL(phonet_stream_ops);
359
360 static DEFINE_MUTEX(port_mutex);
361
362 /* allocate port for a socket */
363 int pn_sock_get_port(struct sock *sk, unsigned short sport)
364 {
365         static int port_cur;
366         struct net *net = sock_net(sk);
367         struct pn_sock *pn = pn_sk(sk);
368         struct sockaddr_pn try_sa;
369         struct sock *tmpsk;
370
371         memset(&try_sa, 0, sizeof(struct sockaddr_pn));
372         try_sa.spn_family = AF_PHONET;
373
374         mutex_lock(&port_mutex);
375
376         if (!sport) {
377                 /* search free port */
378                 int port, pmin, pmax;
379
380                 phonet_get_local_port_range(&pmin, &pmax);
381                 for (port = pmin; port <= pmax; port++) {
382                         port_cur++;
383                         if (port_cur < pmin || port_cur > pmax)
384                                 port_cur = pmin;
385
386                         pn_sockaddr_set_port(&try_sa, port_cur);
387                         tmpsk = pn_find_sock_by_sa(net, &try_sa);
388                         if (tmpsk == NULL) {
389                                 sport = port_cur;
390                                 goto found;
391                         } else
392                                 sock_put(tmpsk);
393                 }
394         } else {
395                 /* try to find specific port */
396                 pn_sockaddr_set_port(&try_sa, sport);
397                 tmpsk = pn_find_sock_by_sa(net, &try_sa);
398                 if (tmpsk == NULL)
399                         /* No sock there! We can use that port... */
400                         goto found;
401                 else
402                         sock_put(tmpsk);
403         }
404         mutex_unlock(&port_mutex);
405
406         /* the port must be in use already */
407         return -EADDRINUSE;
408
409 found:
410         mutex_unlock(&port_mutex);
411         pn->sobject = pn_object(pn_addr(pn->sobject), sport);
412         return 0;
413 }
414 EXPORT_SYMBOL(pn_sock_get_port);