presage  0.8.8
smoothedNgramPredictor.cpp

/******************************************************
 * Presage, an extensible predictive text entry system
 * ---------------------------------------------------
 *
 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License along
 with this program; if not, write to the Free Software Foundation, Inc.,
 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 **********(*)*/

#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>


SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),
      cardinality (0),
      learn_mode_set (false),
      dispatcher (this)
{
    LOGGER          = PREDICTORS + name + ".LOGGER";
    DBFILENAME      = PREDICTORS + name + ".DBFILENAME";
    DELTAS          = PREDICTORS + name + ".DELTAS";
    LEARN           = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";

    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
}
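// For example, for a predictor instantiated under the name
// "SmoothedNgramPredictor", and assuming the PREDICTORS variable holds
// the usual "Presage.Predictors." namespace prefix (an assumption;
// PREDICTORS is defined elsewhere), DELTAS resolves to the
// configuration variable
// Presage.Predictors.SmoothedNgramPredictor.DELTAS.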


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    cardinality = 0;
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
        cardinality++;
    }
    logger << INFO << "DELTAS: " << value << endl;
    logger << INFO << "CARDINALITY: " << cardinality << endl;

    init_database_connector_if_ready ();
}

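// For example, a DELTAS value of "0.25 0.25 0.50" yields cardinality 3,
// i.e. a trigram model in which the unigram, bigram and trigram relative
// frequencies are weighted by 0.25, 0.25 and 0.50 respectively when
// predict() computes the smoothed probability of a candidate.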

void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    wanna_learn = Utility::isTrue (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::init_database_connector_if_ready ()
{
    // we can only init the sqlite database connector once we know the
    // following:
    //  - what database file we need to open
    //  - what cardinality we expect the database file to be
    //  - whether we need to open the database in read-only or
    //    read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;

        if (dbloglevel.empty ()) {
            // open database connector
            db = new SqliteDatabaseConnector (dbfilename,
                                              cardinality,
                                              wanna_learn);
        } else {
            // open database connector with logger level
            db = new SqliteDatabaseConnector (dbfilename,
                                              cardinality,
                                              wanna_learn,
                                              dbloglevel);
        }
    }
}
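// Since init_database_connector_if_ready() is invoked from each of the
// setters above, the connector is (re)created as soon as DBFILENAME,
// DELTAS and LEARN have all been dispatched, and is recreated whenever
// any of them subsequently changes.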


// convenience function to convert ngram to string
//
static std::string ngram_to_string(const Ngram& ngram)
{
    const char separator[] = "|";
    std::string result = separator;

    for (Ngram::const_iterator it = ngram.begin();
         it != ngram.end();
         it++)
    {
        result += *it + separator;
    }

    return result;
}

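// For example, ngram_to_string applied to the Ngram {"the", "quick",
// "brown"} returns the string "|the|quick|brown|".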

/** Builds the required n-gram and returns its count.
 *
 * @param tokens     tokens out of which to build the n-gram
 * @param offset     entry point into tokens, relative to the last token
 * @param ngram_size size of the n-gram whose count is to be returned
 * @return count of the n-gram
 */
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    unsigned int result = 0;

    assert(offset <= 0); // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());
        result = db->getNgramCount(ngram);
        logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
    } else {
        result = db->getUnigramCountsSum();
        logger << DEBUG << "unigram counts sum: " << result << endl;
    }

    return result;
}

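// For example, with tokens = {"the", "quick", "brown"}:
//   count(tokens,  0, 2) returns the database count of the bigram
//                        "quick brown" (the last two tokens);
//   count(tokens, -1, 2) returns the count of the bigram "the quick"
//                        (the same window shifted back by one token);
//   count(tokens,  0, 0) returns the sum of all unigram counts.
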
Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[cardinality - 1 - k] caches contextTracker->getToken(k),
    // so that tokens[cardinality - 1] holds the current (partial)
    // token w_i of the generalized smoothed n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }
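
    // For example, with cardinality 3 and the user having typed
    // "the quick bro", and assuming contextTracker->getToken(0) returns
    // the partial token under the cursor, the cache would hold
    // tokens[0] = "the", tokens[1] = "quick", tokens[2] = "bro".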

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, since in a well-constructed n-gram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower-order n-gram tables if the
    // initial completion set is smaller than required (see the example
    // below).
    //
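    // For example, with cardinality 3 and a current prefix "fo", the
    // candidates are first drawn from the _3_gram table (tokens starting
    // with "fo" seen after the two preceding context tokens), then from
    // the _2_gram table, then from the _1_gram table, until
    // max_partial_prediction_size candidates have been collected.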
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates, iterator it points to Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
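    // The smoothed probability of each candidate w_i is the linear
    // interpolation of its n-gram relative frequencies:
    //
    //   P(w_i | w_{i-n+1} ... w_{i-1}) =
    //       sum_{k=0}^{n-1} deltas[k] * C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // where n is the cardinality, C() is the n-gram count stored in the
    // database, and the k = 0 denominator degenerates to the sum of all
    // unigram counts.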
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // for some sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (wanna_learn) {
        // learning is turned on

        try
        {
            db->beginTransaction();

            for (size_t curr_cardinality = 1;
                 curr_cardinality < cardinality + 1;
                 curr_cardinality++) {

                logger << DEBUG << "Learning for n-gram cardinality: " << curr_cardinality << endl;

                // idx walks the change vector back to front
                for (std::vector<std::string>::const_reverse_iterator idx = change.rbegin();
                     idx != change.rend();
                     idx++)
                {
                    Ngram ngram;

                    // try to fill in the ngram to be learnt with change
                    // tokens first
                    for (std::vector<std::string>::const_reverse_iterator inner_idx = idx;
                         inner_idx != change.rend() && ngram.size() < curr_cardinality;
                         inner_idx++)
                    {
                        ngram.insert(ngram.begin(), *inner_idx);
                    }

                    logger << DEBUG << "After filling n-gram with change tokens: " << ngram_to_string(ngram) << endl;

                    // then use (past stream - change) if ngram not filled in yet
                    for (int tk_idx = 1;
                         ngram.size() < curr_cardinality;
                         tk_idx++)
                    {
                        // getExtraTokenToLearn returns tokens from
                        // past stream that come before and are not in
                        // change vector
                        //
                        std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                        logger << DEBUG << "Adding extra token: " << extra_token << endl;
                        ngram.insert(ngram.begin(), extra_token);
                    }

                    // now we have built the ngram we have to learn
                    logger << INFO << "Considering to learn ngram: |";
                    for (size_t j = 0; j < ngram.size(); j++) {
                        logger << INFO << ngram[j] << '|';
                    }
                    logger << INFO << endl;

                    if (ngram.end() == find(ngram.begin(), ngram.end(), "")) {
                        // only learn ngram if it doesn't contain empty strings
                        db->incrementNgramCount(ngram);
                        check_learn_consistency(ngram);
                        logger << INFO << "Learnt ngram" << endl;
                    } else {
                        logger << INFO << "Discarded ngram" << endl;
                    }
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        catch (const std::exception& ex)
        {
            db->rollbackTransaction();
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}
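
// For example, learning the change {"hello", "world"} with cardinality 2
// increments the counts of the unigrams "hello" and "world", of the
// bigram "hello world", and of the bigram formed by prepending to
// "hello" the token that preceded it in the past stream (if any).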

void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t i = 0; i < sub_ngram.size(); i++) {
                    logger << sub_ngram[i] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
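
// Consistency here means that an n-gram's count may never exceed the
// count of the (n-1)-gram obtained by dropping its last token, e.g.
// C("the quick brown") <= C("the quick"). After learn() increments the
// full n-gram's count, this check bumps any prefix count that has
// fallen behind.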

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}