Solve Socket vs URI IPv6 handling in Ruby
[rbot] / lib / rbot / ircsocket.rb
1 #-- vim:sw=2:et
2 #++
3 #
4 # :title: IRC Socket
5 #
6 # This module implements the IRC socket interface, including IRC message
7 # penalty computation and the message queue system
8
9 require 'monitor'
10
11 class ::String
12   # Calculate the penalty which will be assigned to this message
13   # by the IRCd
14   def irc_send_penalty
15     # According to eggdrop, the initial penalty is
16     penalty = 1 + self.size/100
17     # on everything but UnderNET where it's
18     # penalty = 2 + self.size/120
19
20     cmd, pars = self.split($;,2)
21     debug "cmd: #{cmd}, pars: #{pars.inspect}"
22     case cmd.to_sym
23     when :KICK
24       chan, nick, msg = pars.split
25       chan = chan.split(',')
26       nick = nick.split(',')
27       penalty += nick.size
28       penalty *= chan.size
29     when :MODE
30       chan, modes, argument = pars.split
31       extra = 0
32       if modes
33         extra = 1
34         if argument
35           extra += modes.split(/\+|-/).size
36         else
37           extra += 3 * modes.split(/\+|-/).size
38         end
39       end
40       if argument
41         extra += 2 * argument.split.size
42       end
43       penalty += extra * chan.split.size
44     when :TOPIC
45       penalty += 1
46       penalty += 2 unless pars.split.size < 2
47     when :PRIVMSG, :NOTICE
48       dests = pars.split($;,2).first
49       penalty += dests.split(',').size
50     when :WHO
51       args = pars.split
52       if args.length > 0
53         penalty += args.inject(0){ |sum,x| sum += ((x.length > 4) ? 3 : 5) }
54       else
55         penalty += 10
56       end
57     when :PART
58       penalty += 4
59     when :AWAY, :JOIN, :VERSION, :TIME, :TRACE, :WHOIS, :DNS
60       penalty += 2
61     when :INVITE, :NICK
62       penalty += 3
63     when :ISON
64       penalty += 1
65     else # Unknown messages
66       penalty += 1
67     end
68     if penalty > 99
69       debug "Wow, more than 99 secs of penalty!"
70       penalty = 99
71     end
72     if penalty < 2
73       debug "Wow, less than 2 secs of penalty!"
74       penalty = 2
75     end
76     debug "penalty: #{penalty}"
77     return penalty
78   end
79 end
80
81 module Irc
82
83   require 'socket'
84   require 'thread'
85
86   class QueueRing
87     # A QueueRing is implemented as an array with elements in the form
88     # [chan, [message1, message2, ...]
89     # Note that the channel +chan+ has no actual bearing with the channels
90     # to which messages will be sent
91
92     def initialize
93       @storage = Array.new
94       @last_idx = -1
95     end
96
97     def clear
98       @storage.clear
99       @last_idx = -1
100     end
101
102     def length
103       len = 0
104       @storage.each {|c|
105         len += c[1].size
106       }
107       return len
108     end
109     alias :size :length
110
111     def empty?
112       @storage.empty?
113     end
114
115     def push(mess, chan)
116       cmess = @storage.assoc(chan)
117       if cmess
118         idx = @storage.index(cmess)
119         cmess[1] << mess
120         @storage[idx] = cmess
121       else
122         @storage << [chan, [mess]]
123       end
124     end
125
126     def next
127       if empty?
128         warning "trying to access empty ring"
129         return nil
130       end
131       save_idx = @last_idx
132       @last_idx = (@last_idx + 1) % @storage.size
133       mess = @storage[@last_idx][1].first
134       @last_idx = save_idx
135       return mess
136     end
137
138     def shift
139       if empty?
140         warning "trying to access empty ring"
141         return nil
142       end
143       @last_idx = (@last_idx + 1) % @storage.size
144       mess = @storage[@last_idx][1].shift
145       @storage.delete(@storage[@last_idx]) if @storage[@last_idx][1] == []
146       return mess
147     end
148
149   end
150
151   class MessageQueue
152
153     def initialize
154       # a MessageQueue is an array of QueueRings
155       # rings have decreasing priority, so messages in ring 0
156       # are more important than messages in ring 1, and so on
157       @rings = Array.new(3) { |i|
158         if i > 0
159           QueueRing.new
160         else
161           # ring 0 is special in that if it's not empty, it will
162           # be popped. IOW, ring 0 can starve the other rings
163           # ring 0 is strictly FIFO and is therefore implemented
164           # as an array
165           Array.new
166         end
167       }
168       # the other rings are satisfied round-robin
169       @last_ring = 0
170       self.extend(MonitorMixin)
171       @non_empty = self.new_cond
172     end
173
174     def clear
175       self.synchronize do
176         @rings.each { |r| r.clear }
177         @last_ring = 0
178       end
179     end
180
181     def push(mess, chan=nil, cring=0)
182       ring = cring
183       self.synchronize do
184         if ring == 0
185           warning "message #{mess} at ring 0 has channel #{chan}: channel will be ignored" if !chan.nil?
186           @rings[0] << mess
187         else
188           error "message #{mess} at ring #{ring} must have a channel" if chan.nil?
189           @rings[ring].push mess, chan
190         end
191         @non_empty.signal
192       end
193     end
194
195     def shift(tmout = nil)
196       self.synchronize do
197         @non_empty.wait(tmout) if self.empty?
198         return unsafe_shift
199       end
200     end
201
202     protected
203
204     def empty?
205       !@rings.find { |r| !r.empty? }
206     end
207
208     def length
209       @rings.inject(0) { |s, r| s + r.size }
210     end
211     alias :size :length
212
213     def unsafe_shift
214       if !@rings[0].empty?
215         return @rings[0].shift
216       end
217       (@rings.size - 1).times do
218         @last_ring = (@last_ring % (@rings.size - 1)) + 1
219         return @rings[@last_ring].shift unless @rings[@last_ring].empty?
220       end
221       warning "trying to access an empty message queue"
222       return nil
223     end
224
225   end
226
227   # wrapped TCPSocket for communication with the server.
228   # emulates a subset of TCPSocket functionality
229   class Socket
230
231     MAX_IRC_SEND_PENALTY = 10
232
233     # total number of lines sent to the irc server
234     attr_reader :lines_sent
235
236     # total number of lines received from the irc server
237     attr_reader :lines_received
238
239     # total number of bytes sent to the irc server
240     attr_reader :bytes_sent
241
242     # total number of bytes received from the irc server
243     attr_reader :bytes_received
244
245     # accumulator for the throttle
246     attr_reader :throttle_bytes
247
248     # an optional filter object. we call @filter.in(data) for
249     # all incoming data and @filter.out(data) for all outgoing data
250     attr_reader :filter
251
252     # normalized uri of the current server
253     attr_reader :server_uri
254
255     # penalty multiplier (percent)
256     attr_accessor :penalty_pct
257
258     # default trivial filter class
259     class IdentityFilter
260         def in(x)
261             x
262         end
263
264         def out(x)
265             x
266         end
267     end
268
269     # set filter to identity, not to nil
270     def filter=(f)
271         @filter = f || IdentityFilter.new
272     end
273
274     # server_list:: list of servers to connect to
275     # host::   optional local host to bind to (ruby 1.7+ required)
276     # create a new Irc::Socket
277     def initialize(server_list, host, opts={})
278       @server_list = server_list.dup
279       @server_uri = nil
280       @conn_count = 0
281       @host = host
282       @sock = nil
283       @filter = IdentityFilter.new
284       @spooler = false
285       @lines_sent = 0
286       @lines_received = 0
287       @ssl = opts[:ssl]
288       @penalty_pct = opts[:penalty_pct] || 100
289     end
290
291     def connected?
292       !@sock.nil?
293     end
294
295     # open a TCP connection to the server
296     def connect
297       if connected?
298         warning "reconnecting while connected"
299         return
300       end
301       srv_uri = @server_list[@conn_count % @server_list.size].dup
302       srv_uri = 'irc://' + srv_uri if !(srv_uri =~ /:\/\//)
303       @conn_count += 1
304       @server_uri = URI.parse(srv_uri)
305       @server_uri.port = 6667 if !@server_uri.port
306
307       debug "connection attempt \##{@conn_count} (#{@server_uri.host}:#{@server_uri.port})"
308
309       # if the host is a bracketed (IPv6) address, strip the brackets
310       # since Ruby doesn't like them in the Socket host parameter
311       # FIXME it would be safer to have it check for a valid
312       # IPv6 bracketed address rather than just stripping the brackets
313       srv_host = @server_uri.host
314       if srv_host.match(/\A\[(.*)\]\z/)
315         srv_host = $1
316       end
317
318       if(@host)
319         begin
320           sock=TCPSocket.new(srv_host, @server_uri.port, @host)
321         rescue ArgumentError => e
322           error "Your version of ruby does not support binding to a "
323           error "specific local address, please upgrade if you wish "
324           error "to use HOST = foo"
325           error "(this option has been disabled in order to continue)"
326           sock=TCPSocket.new(srv_host, @server_uri.port)
327         end
328       else
329         sock=TCPSocket.new(srv_host, @server_uri.port)
330       end
331       if(@ssl)
332         require 'openssl'
333         ssl_context = OpenSSL::SSL::SSLContext.new()
334         ssl_context.verify_mode = OpenSSL::SSL::VERIFY_NONE
335         sock = OpenSSL::SSL::SSLSocket.new(sock, ssl_context)
336         sock.sync_close = true
337         sock.connect
338       end
339       @sock = sock
340       @last_send = Time.new
341       @flood_send = Time.new
342       @burst = 0
343       @sock.extend(MonitorMixin)
344       @sendq = MessageQueue.new
345       @qthread = Thread.new { writer_loop }
346     end
347
348     # used to send lines to the remote IRCd by skipping the queue
349     # message: IRC message to send
350     # it should only be used for stuff that *must not* be queued,
351     # i.e. the initial PASS, NICK and USER command
352     # or the final QUIT message
353     def emergency_puts(message, penalty = false)
354       @sock.synchronize do
355         # debug "In puts - got @sock"
356         puts_critical(message, penalty)
357       end
358     end
359
360     def handle_socket_error(string, e)
361       error "#{string} failed: #{e.pretty_inspect}"
362       # We assume that an error means that there are connection
363       # problems and that we should reconnect, so we
364       shutdown
365       raise SocketError.new(e.inspect)
366     end
367
368     # get the next line from the server (blocks)
369     def gets
370       if @sock.nil?
371         warning "socket get attempted while closed"
372         return nil
373       end
374       begin
375         reply = @filter.in(@sock.gets)
376         @lines_received += 1
377         reply.strip! if reply
378         debug "RECV: #{reply.inspect}"
379         return reply
380       rescue Exception => e
381         handle_socket_error(:RECV, e)
382       end
383     end
384
385     def queue(msg, chan=nil, ring=0)
386       @sendq.push msg, chan, ring
387     end
388
389     def clearq
390       @sendq.clear
391     end
392
393     # flush the TCPSocket
394     def flush
395       @sock.flush
396     end
397
398     # Wraps Kernel.select on the socket
399     def select(timeout=nil)
400       Kernel.select([@sock], nil, nil, timeout)
401     end
402
403     # shutdown the connection to the server
404     def shutdown(how=2)
405       return unless connected?
406       @qthread.kill
407       @qthread = nil
408       begin
409         @sock.close
410       rescue Exception => e
411         error "error while shutting down: #{e.pretty_inspect}"
412       end
413       @sock = nil
414       @server_uri = nil
415       @sendq.clear
416     end
417
418     private
419
420     def writer_loop
421       loop do
422         begin
423           now = Time.now
424           flood_delay = @flood_send - MAX_IRC_SEND_PENALTY - now
425           delay = [flood_delay, 0].max
426           if delay > 0
427             debug "sleep(#{delay}) # (f: #{flood_delay})"
428             sleep(delay)
429           end
430           msg = @sendq.shift
431           debug "got #{msg.inspect} from queue, sending"
432           emergency_puts(msg, true)
433         rescue Exception => e
434           error "Spooling failed: #{e.pretty_inspect}"
435           debug e.backtrace.join("\n")
436           raise e
437         end
438       end
439     end
440
441     # same as puts, but expects to be called with a lock held on @sock
442     def puts_critical(message, penalty=false)
443       # debug "in puts_critical"
444       begin
445         debug "SEND: #{message.inspect}"
446         if @sock.nil?
447           error "SEND attempted on closed socket"
448         else
449           # we use Socket#syswrite() instead of Socket#puts() because
450           # the latter is racy and can cause double message output in
451           # some circumstances
452           actual = @filter.out(message) + "\n"
453           now = Time.new
454           @sock.syswrite actual
455           @last_send = now
456           @flood_send = now if @flood_send < now
457           @flood_send += message.irc_send_penalty*@penalty_pct/100.0 if penalty
458           @lines_sent += 1
459         end
460       rescue Exception => e
461         handle_socket_error(:SEND, e)
462       end
463     end
464
465   end
466
467 end