Merge branch 'linus' into core/urgent
[linux-2.6] / net / ipv4 / inet_hashtables.c
1 /*
2  * INET         An implementation of the TCP/IP protocol suite for the LINUX
3  *              operating system.  INET is implemented using the BSD Socket
4  *              interface as the means of communication with the user level.
5  *
6  *              Generic INET transport hashtables
7  *
8  * Authors:     Lotsa people, from code originally in tcp
9  *
10  *      This program is free software; you can redistribute it and/or
11  *      modify it under the terms of the GNU General Public License
12  *      as published by the Free Software Foundation; either version
13  *      2 of the License, or (at your option) any later version.
14  */
15
16 #include <linux/module.h>
17 #include <linux/random.h>
18 #include <linux/sched.h>
19 #include <linux/slab.h>
20 #include <linux/wait.h>
21
22 #include <net/inet_connection_sock.h>
23 #include <net/inet_hashtables.h>
24 #include <net/ip.h>
25
26 /*
27  * Allocate and initialize a new local port bind bucket.
28  * The bindhash spinlock (head->lock) for snum's hash chain must be held here.
29  */
30 struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
31                                                  struct net *net,
32                                                  struct inet_bind_hashbucket *head,
33                                                  const unsigned short snum)
34 {
35         struct inet_bind_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
36
37         if (tb != NULL) {
38                 tb->ib_net       = hold_net(net);
39                 tb->port      = snum;
40                 tb->fastreuse = 0;
41                 INIT_HLIST_HEAD(&tb->owners);
42                 hlist_add_head(&tb->node, &head->chain);
43         }
44         return tb;
45 }
46
47 /*
48  * Caller must hold hashbucket lock for this tb with local BH disabled
49  */
50 void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb)
51 {
52         if (hlist_empty(&tb->owners)) {
53                 __hlist_del(&tb->node);
54                 release_net(tb->ib_net);
55                 kmem_cache_free(cachep, tb);
56         }
57 }
58
59 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
60                     const unsigned short snum)
61 {
62         inet_sk(sk)->num = snum;
63         sk_add_bind_node(sk, &tb->owners);
64         inet_csk(sk)->icsk_bind_hash = tb;
65 }
66
67 /*
68  * Get rid of any references to a local port held by the given sock.
69  */
70 static void __inet_put_port(struct sock *sk)
71 {
72         struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
73         const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->num,
74                         hashinfo->bhash_size);
75         struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
76         struct inet_bind_bucket *tb;
77
78         spin_lock(&head->lock);
79         tb = inet_csk(sk)->icsk_bind_hash;
80         __sk_del_bind_node(sk);
81         inet_csk(sk)->icsk_bind_hash = NULL;
82         inet_sk(sk)->num = 0;
83         inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
84         spin_unlock(&head->lock);
85 }
86
/*
 * Release sk's local port binding.  __inet_put_port() takes the bind
 * chain lock with plain spin_lock(), so bottom halves are disabled
 * around the call here.
 */
void inet_put_port(struct sock *sk)
{
	local_bh_disable();
	__inet_put_port(sk);
	local_bh_enable();
}

EXPORT_SYMBOL(inet_put_port);
95
96 void __inet_inherit_port(struct sock *sk, struct sock *child)
97 {
98         struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
99         const int bhash = inet_bhashfn(sock_net(sk), inet_sk(child)->num,
100                         table->bhash_size);
101         struct inet_bind_hashbucket *head = &table->bhash[bhash];
102         struct inet_bind_bucket *tb;
103
104         spin_lock(&head->lock);
105         tb = inet_csk(sk)->icsk_bind_hash;
106         sk_add_bind_node(child, &tb->owners);
107         inet_csk(child)->icsk_bind_hash = tb;
108         spin_unlock(&head->lock);
109 }
110
111 EXPORT_SYMBOL_GPL(__inet_inherit_port);
112
113 /*
114  * This lock without WQ_FLAG_EXCLUSIVE is good on UP and it can be very bad on SMP.
115  * Look, when several writers sleep and reader wakes them up, all but one
116  * immediately hit write lock and grab all the cpus. Exclusive sleep solves
117  * this, _but_ remember, it adds useless work on UP machines (wake up each
118  * exclusive lock release). It should be ifdefed really.
119  */
/*
 * Take hashinfo->lhash_lock for writing and return with it held, first
 * waiting (exclusively, on lhash_wait) until lhash_users reads as zero.
 *
 * NOTE(review): the callers in this file (__inet_hash via inet_hash, and
 * inet_unhash) disable BHs with local_bh_disable() before calling this.
 * The wait loop drops the lock with write_unlock_bh(), which re-enables BHs
 * so that schedule() may be called, and retakes it with write_lock_bh(),
 * which disables them again.  Net BH-disable depth is therefore unchanged,
 * but this relies on the caller having disabled BHs.
 */
120 void inet_listen_wlock(struct inet_hashinfo *hashinfo)
121         __acquires(hashinfo->lhash_lock)
122 {
123         write_lock(&hashinfo->lhash_lock);
124
125         if (atomic_read(&hashinfo->lhash_users)) {
126                 DEFINE_WAIT(wait);
127
128                 for (;;) {
                            /* Queue on lhash_wait before re-checking, so a
                             * wake_up() between the check and schedule()
                             * is not lost. */
129                         prepare_to_wait_exclusive(&hashinfo->lhash_wait,
130                                                   &wait, TASK_UNINTERRUPTIBLE);
131                         if (!atomic_read(&hashinfo->lhash_users))
132                                 break;
                            /* Cannot sleep holding the lock with BHs off:
                             * release both, sleep, then reacquire both. */
133                         write_unlock_bh(&hashinfo->lhash_lock);
134                         schedule();
135                         write_lock_bh(&hashinfo->lhash_lock);
136                 }
137
                    /* Lock is held here; dequeue from lhash_wait. */
138                 finish_wait(&hashinfo->lhash_wait, &wait);
139         }
140 }
141
142 /*
143  * Don't inline this cruft. Here are some nice properties to exploit here. The
144  * BSD API does not allow a listening sock to specify the remote port nor the
145  * remote address for the connection. So always assume those are both
146  * wildcarded during the search since they can never be otherwise.
147  */
148 static struct sock *inet_lookup_listener_slow(struct net *net,
149                                               const struct hlist_head *head,
150                                               const __be32 daddr,
151                                               const unsigned short hnum,
152                                               const int dif)
153 {
154         struct sock *result = NULL, *sk;
155         const struct hlist_node *node;
156         int hiscore = -1;
157
158         sk_for_each(sk, node, head) {
159                 const struct inet_sock *inet = inet_sk(sk);
160
161                 if (net_eq(sock_net(sk), net) && inet->num == hnum &&
162                                 !ipv6_only_sock(sk)) {
163                         const __be32 rcv_saddr = inet->rcv_saddr;
164                         int score = sk->sk_family == PF_INET ? 1 : 0;
165
166                         if (rcv_saddr) {
167                                 if (rcv_saddr != daddr)
168                                         continue;
169                                 score += 2;
170                         }
171                         if (sk->sk_bound_dev_if) {
172                                 if (sk->sk_bound_dev_if != dif)
173                                         continue;
174                                 score += 2;
175                         }
176                         if (score == 5)
177                                 return sk;
178                         if (score > hiscore) {
179                                 hiscore = score;
180                                 result  = sk;
181                         }
182                 }
183         }
184         return result;
185 }
186
187 /* Optimize the common listener case. */
188 struct sock *__inet_lookup_listener(struct net *net,
189                                     struct inet_hashinfo *hashinfo,
190                                     const __be32 daddr, const unsigned short hnum,
191                                     const int dif)
192 {
193         struct sock *sk = NULL;
194         const struct hlist_head *head;
195
196         read_lock(&hashinfo->lhash_lock);
197         head = &hashinfo->listening_hash[inet_lhashfn(net, hnum)];
198         if (!hlist_empty(head)) {
199                 const struct inet_sock *inet = inet_sk((sk = __sk_head(head)));
200
201                 if (inet->num == hnum && !sk->sk_node.next &&
202                     (!inet->rcv_saddr || inet->rcv_saddr == daddr) &&
203                     (sk->sk_family == PF_INET || !ipv6_only_sock(sk)) &&
204                     !sk->sk_bound_dev_if && net_eq(sock_net(sk), net))
205                         goto sherry_cache;
206                 sk = inet_lookup_listener_slow(net, head, daddr, hnum, dif);
207         }
208         if (sk) {
209 sherry_cache:
210                 sock_hold(sk);
211         }
212         read_unlock(&hashinfo->lhash_lock);
213         return sk;
214 }
215 EXPORT_SYMBOL_GPL(__inet_lookup_listener);
216
217 struct sock * __inet_lookup_established(struct net *net,
218                                   struct inet_hashinfo *hashinfo,
219                                   const __be32 saddr, const __be16 sport,
220                                   const __be32 daddr, const u16 hnum,
221                                   const int dif)
222 {
223         INET_ADDR_COOKIE(acookie, saddr, daddr)
224         const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
225         struct sock *sk;
226         const struct hlist_node *node;
227         /* Optimize here for direct hit, only listening connections can
228          * have wildcards anyways.
229          */
230         unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport);
231         struct inet_ehash_bucket *head = inet_ehash_bucket(hashinfo, hash);
232         rwlock_t *lock = inet_ehash_lockp(hashinfo, hash);
233
234         prefetch(head->chain.first);
235         read_lock(lock);
236         sk_for_each(sk, node, &head->chain) {
237                 if (INET_MATCH(sk, net, hash, acookie,
238                                         saddr, daddr, ports, dif))
239                         goto hit; /* You sunk my battleship! */
240         }
241
242         /* Must check for a TIME_WAIT'er before going to listener hash. */
243         sk_for_each(sk, node, &head->twchain) {
244                 if (INET_TW_MATCH(sk, net, hash, acookie,
245                                         saddr, daddr, ports, dif))
246                         goto hit;
247         }
248         sk = NULL;
249 out:
250         read_unlock(lock);
251         return sk;
252 hit:
253         sock_hold(sk);
254         goto out;
255 }
256 EXPORT_SYMBOL_GPL(__inet_lookup_established);
257
258 /* called with local bh disabled */
259 static int __inet_check_established(struct inet_timewait_death_row *death_row,
260                                     struct sock *sk, __u16 lport,
261                                     struct inet_timewait_sock **twp)
262 {
263         struct inet_hashinfo *hinfo = death_row->hashinfo;
264         struct inet_sock *inet = inet_sk(sk);
265         __be32 daddr = inet->rcv_saddr;
266         __be32 saddr = inet->daddr;
267         int dif = sk->sk_bound_dev_if;
268         INET_ADDR_COOKIE(acookie, saddr, daddr)
269         const __portpair ports = INET_COMBINED_PORTS(inet->dport, lport);
270         struct net *net = sock_net(sk);
271         unsigned int hash = inet_ehashfn(net, daddr, lport, saddr, inet->dport);
272         struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
273         rwlock_t *lock = inet_ehash_lockp(hinfo, hash);
274         struct sock *sk2;
275         const struct hlist_node *node;
276         struct inet_timewait_sock *tw;
277
278         prefetch(head->chain.first);
279         write_lock(lock);
280
281         /* Check TIME-WAIT sockets first. */
282         sk_for_each(sk2, node, &head->twchain) {
283                 tw = inet_twsk(sk2);
284
285                 if (INET_TW_MATCH(sk2, net, hash, acookie,
286                                         saddr, daddr, ports, dif)) {
287                         if (twsk_unique(sk, sk2, twp))
288                                 goto unique;
289                         else
290                                 goto not_unique;
291                 }
292         }
293         tw = NULL;
294
295         /* And established part... */
296         sk_for_each(sk2, node, &head->chain) {
297                 if (INET_MATCH(sk2, net, hash, acookie,
298                                         saddr, daddr, ports, dif))
299                         goto not_unique;
300         }
301
302 unique:
303         /* Must record num and sport now. Otherwise we will see
304          * in hash table socket with a funny identity. */
305         inet->num = lport;
306         inet->sport = htons(lport);
307         sk->sk_hash = hash;
308         WARN_ON(!sk_unhashed(sk));
309         __sk_add_node(sk, &head->chain);
310         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
311         write_unlock(lock);
312
313         if (twp) {
314                 *twp = tw;
315                 NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
316         } else if (tw) {
317                 /* Silly. Should hash-dance instead... */
318                 inet_twsk_deschedule(tw, death_row);
319                 NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
320
321                 inet_twsk_put(tw);
322         }
323
324         return 0;
325
326 not_unique:
327         write_unlock(lock);
328         return -EADDRNOTAVAIL;
329 }
330
331 static inline u32 inet_sk_port_offset(const struct sock *sk)
332 {
333         const struct inet_sock *inet = inet_sk(sk);
334         return secure_ipv4_port_ephemeral(inet->rcv_saddr, inet->daddr,
335                                           inet->dport);
336 }
337
338 void __inet_hash_nolisten(struct sock *sk)
339 {
340         struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
341         struct hlist_head *list;
342         rwlock_t *lock;
343         struct inet_ehash_bucket *head;
344
345         WARN_ON(!sk_unhashed(sk));
346
347         sk->sk_hash = inet_sk_ehashfn(sk);
348         head = inet_ehash_bucket(hashinfo, sk->sk_hash);
349         list = &head->chain;
350         lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
351
352         write_lock(lock);
353         __sk_add_node(sk, list);
354         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
355         write_unlock(lock);
356 }
357 EXPORT_SYMBOL_GPL(__inet_hash_nolisten);
358
359 static void __inet_hash(struct sock *sk)
360 {
361         struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
362         struct hlist_head *list;
363         rwlock_t *lock;
364
365         if (sk->sk_state != TCP_LISTEN) {
366                 __inet_hash_nolisten(sk);
367                 return;
368         }
369
370         WARN_ON(!sk_unhashed(sk));
371         list = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
372         lock = &hashinfo->lhash_lock;
373
374         inet_listen_wlock(hashinfo);
375         __sk_add_node(sk, list);
376         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
377         write_unlock(lock);
378         wake_up(&hashinfo->lhash_wait);
379 }
380
381 void inet_hash(struct sock *sk)
382 {
383         if (sk->sk_state != TCP_CLOSE) {
384                 local_bh_disable();
385                 __inet_hash(sk);
386                 local_bh_enable();
387         }
388 }
389 EXPORT_SYMBOL_GPL(inet_hash);
390
391 void inet_unhash(struct sock *sk)
392 {
393         rwlock_t *lock;
394         struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
395
396         if (sk_unhashed(sk))
397                 goto out;
398
399         if (sk->sk_state == TCP_LISTEN) {
400                 local_bh_disable();
401                 inet_listen_wlock(hashinfo);
402                 lock = &hashinfo->lhash_lock;
403         } else {
404                 lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
405                 write_lock_bh(lock);
406         }
407
408         if (__sk_del_node_init(sk))
409                 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
410         write_unlock_bh(lock);
411 out:
412         if (sk->sk_state == TCP_LISTEN)
413                 wake_up(&hashinfo->lhash_wait);
414 }
415 EXPORT_SYMBOL_GPL(inet_unhash);
416
/*
 * Bind @sk to a local port for an outgoing connection and hash it.
 *
 * When @sk has no local port yet (num == 0), walk the local port range,
 * starting from an offset built from @port_offset plus a function-static
 * hint.  A port is accepted if it has no bind bucket yet (a new bucket is
 * created and marked fastreuse = -1).  It is also accepted if its bucket
 * has fastreuse < 0 and @check_established succeeds for it.
 * When @sk already has a local port, hash it directly if it is the sole
 * owner of its bind bucket; otherwise defer to @check_established.
 *
 * @port_offset: per-connection starting offset for the port search.
 * @check_established: checks four-tuple uniqueness; as
 *	__inet_check_established does, it is expected to hash @sk itself on
 *	success and may return a TIME_WAIT socket through its last argument.
 * @hash: inserts @sk into the established hash.
 *
 * Returns 0 on success, -EADDRNOTAVAIL if no ephemeral port could be used,
 * or the result of @check_established on the already-bound path.
 */
417 int __inet_hash_connect(struct inet_timewait_death_row *death_row,
418                 struct sock *sk, u32 port_offset,
419                 int (*check_established)(struct inet_timewait_death_row *,
420                         struct sock *, __u16, struct inet_timewait_sock **),
421                 void (*hash)(struct sock *sk))
422 {
423         struct inet_hashinfo *hinfo = death_row->hashinfo;
424         const unsigned short snum = inet_sk(sk)->num;
425         struct inet_bind_hashbucket *head;
426         struct inet_bind_bucket *tb;
427         int ret;
428         struct net *net = sock_net(sk);
429
            /* No local port yet: search the ephemeral range. */
430         if (!snum) {
431                 int i, remaining, low, high, port;
                    /* NOTE(review): shared by all callers and updated without
                     * a dedicated lock; concurrent updates only perturb the
                     * search start point -- presumably benign, confirm. */
432                 static u32 hint;
433                 u32 offset = hint + port_offset;
434                 struct hlist_node *node;
435                 struct inet_timewait_sock *tw = NULL;
436
437                 inet_get_local_port_range(&low, &high);
438                 remaining = (high - low) + 1;
439
                    /* BHs stay disabled until local_bh_enable() below on
                     * failure, or at out: on success. */
440                 local_bh_disable();
441                 for (i = 1; i <= remaining; i++) {
442                         port = low + (i + offset) % remaining;
443                         head = &hinfo->bhash[inet_bhashfn(net, port,
444                                         hinfo->bhash_size)];
445                         spin_lock(&head->lock);
446
447                         /* Does not bother with rcv_saddr checks,
448                          * because the established check is already
449                          * unique enough.
450                          */
451                         inet_bind_bucket_for_each(tb, node, &head->chain) {
452                                 if (tb->ib_net == net && tb->port == port) {
453                                         WARN_ON(hlist_empty(&tb->owners));
                                            /* Only buckets marked fastreuse < 0
                                             * (this function marks its own
                                             * with -1 below) may be shared. */
454                                         if (tb->fastreuse >= 0)
455                                                 goto next_port;
                                            /* On success, tw may receive a
                                             * TIME_WAIT socket to release. */
456                                         if (!check_established(death_row, sk,
457                                                                 port, &tw))
458                                                 goto ok;
459                                         goto next_port;
460                                 }
461                         }
462
                            /* No bucket for this port: claim it. */
463                         tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
464                                         net, head, port);
465                         if (!tb) {
                                    /* Allocation failure ends the search. */
466                                 spin_unlock(&head->lock);
467                                 break;
468                         }
469                         tb->fastreuse = -1;
470                         goto ok;
471
472                 next_port:
473                         spin_unlock(&head->lock);
474                 }
475                 local_bh_enable();
476
477                 return -EADDRNOTAVAIL;
478
479 ok:
                    /* Start the next search after the port just chosen. */
480                 hint += i;
481
482                 /* Head lock still held and bh's disabled */
483                 inet_bind_hash(sk, tb, port);
                    /* Hash now unless check_established already did. */
484                 if (sk_unhashed(sk)) {
485                         inet_sk(sk)->sport = htons(port);
486                         hash(sk);
487                 }
488                 spin_unlock(&head->lock);
489
                    /* Release a taken-over TIME_WAIT socket outside the
                     * bind-bucket lock. */
490                 if (tw) {
491                         inet_twsk_deschedule(tw, death_row);
492                         inet_twsk_put(tw);
493                 }
494
495                 ret = 0;
496                 goto out;
497         }
498
            /* Local port already set: use its existing bind bucket. */
499         head = &hinfo->bhash[inet_bhashfn(net, snum, hinfo->bhash_size)];
500         tb  = inet_csk(sk)->icsk_bind_hash;
501         spin_lock_bh(&head->lock);
            /* sk is the bucket's only owner: hash without consulting the
             * established table. */
502         if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
503                 hash(sk);
504                 spin_unlock_bh(&head->lock);
505                 return 0;
506         } else {
                    /* Drop the bucket lock but keep BHs disabled for
                     * check_established (__inet_check_established expects
                     * them off); they are re-enabled at out:. */
507                 spin_unlock(&head->lock);
508                 /* No definite answer... Walk to established hash table */
509                 ret = check_established(death_row, sk, snum, NULL);
                    /* The ephemeral-port success path also jumps here. */
510 out:
511                 local_bh_enable();
512                 return ret;
513         }
514 }
515
516 /*
517  * Bind a port for a connect operation and hash it.
518  */
519 int inet_hash_connect(struct inet_timewait_death_row *death_row,
520                       struct sock *sk)
521 {
522         return __inet_hash_connect(death_row, sk, inet_sk_port_offset(sk),
523                         __inet_check_established, __inet_hash_nolisten);
524 }
525
526 EXPORT_SYMBOL_GPL(inet_hash_connect);