Array#pick_one method
[rbot-mark] / mark2.rb
1 # vim: set sw=2 et:
2 # Author: Giuseppe Bilotta <giuseppe.bilotta@gmail.com>
3 # New markov chain plugin
4
5 class Array
6   def butlast
7     first(self.size-1)
8   end
9
10   def butfirst
11     last(self.size-1)
12   end
13
14   def pick_one
15     self[rand(self.size)]
16   end
17 end
18
19 class ChanceHash
20
21   def initialize
22     @hash = Hash.new(0)
23     @picker = {}
24     @total = 0
25     @valid_pick = false
26   end
27
28   def increase(el)
29     if @hash.key?(el)
30       @hash[el] += 1
31     else
32       @hash[el] = 1
33     end
34     @valid_pick = false
35     return @hash[el]
36   end
37
38   def decrease(el)
39     if @hash.key?(el)
40       @hash[el] -= 1
41       @hash.delete(el) if @hash[el] == 0
42     end
43     @valid_pick = false
44     return @hash[el]
45   end
46
47   def make_picker
48     @picker.clear
49     total = 0
50     @hash.each { |el, ch|
51       total += ch
52       @picker[total] = el
53     }
54     @total = total
55     @valid_pick = true
56   end
57
58   def random
59     case @hash.size
60     when 0
61       return nil
62     when 1
63       return @hash.keys.first
64     else
65       make_picker unless @valid_pick
66       pick = rand(@total)
67       @picker.each { |ch, el|
68         return el if pick < ch
69       }
70     end
71   end
72 end
73
74 class MarkovChainer
75   # Maximum depth
76   MAX_ORDER = 5
77
78   # Word or nonword regexp:
79   # can be used to scan a string splitting it into
80   # words and nonwords.
81   WNW = /\w+|\W/u
82
83   def initialize
84     # mkv[i] holds the chains of order i
85     @mkv = Array.new
86
87     # Each chain is in the form
88     # [:array, :of, :symbols] => {
89     #   :prev => ChanceHash,
90     #   :next => ChanceHash
91     # }
92     # except for order 0, which is a simple ChanceHash
93     # itself
94     MAX_ORDER.times { |i|
95       if i == 0
96         @mkv[0] = ChanceHash.new
97       else
98         @mkv[i] = Hash.new { |hash, key|
99           hash[key] = {:prev => ChanceHash.new, :next => ChanceHash.new}
100         }
101       end
102     }
103
104   end
105
106   def add_one(sym)
107     s = sym.to_sym rescue nil
108     @mkv[0].increase(s)
109   end
110
111   def add_before(array, prev)
112     raise "Not enough words in new data" if array.empty?
113     raise "Too many words in new data" if array.size > MAX_ORDER
114     size = array.size
115     h = @mkv[size][array.dup]
116     h[:prev].increase(prev)
117   end
118
119   def add_after(array, nxt)
120     raise "Not enough words in new data" if array.empty?
121     raise "Too many words in new data" if array.size > MAX_ORDER
122     size = array.size
123     h = @mkv[size][array.dup]
124     h[:next].increase(nxt)
125   end
126
127   def add_multi(array)
128     raise "Too many words in new data" if array.size > MAX_ORDER + 1
129     add_before(array.butfirst, array.first)
130     add_after(array.butlast, array.last)
131   end
132
133   def add(*data)
134     if data.size == 1
135       add_one(data.first)
136     else
137       add_multi(data)
138     end
139   end
140
141   def simple_learn(text)
142     syms = text.scan(WNW).map { |w| w.intern } 
143     syms.unshift(nil)
144     syms.push(nil)
145
146     syms.size.times { |i|
147       [MAX_ORDER, syms.size-i].min.times { |ord|
148         v = syms[i, ord+1]
149         # puts "Learning #{v.inspect}"
150         add(*v)
151         # pp @mkv
152       }
153     }
154   end
155
156   def learn(text, o={})
157     opts = {:lowercase => true}.merge o
158
159     lc = opts[:lowercase]
160
161     simple_learn(text)
162     if lc
163       simple_learn(text.downcase)
164     end
165
166     pp @mkv if defined? pp
167   end
168
169   def raw_next(syms)
170     ar = syms.last([MAX_ORDER, syms.size].min)
171     ord = ar.size
172     if ord == 0
173       @mkv[0].random
174     else
175       if @mkv[ord].key?(ar)
176         @mkv[ar][:next].random
177       else
178         raw_next(ar.last(ord-1))
179       end
180     end
181   end
182
183   def next(text)
184     syms = text.scan(WNW).map { |w| w.intern }
185     raw_next(syms)
186   end
187
188 end
189