def compute_ug_probs(self):
    """Compute unigram probabilities for the vocabulary.

    As per (49) of
    https://www.microsoft.com/en-us/research/uploads/prod/2018/06/ExpectedKneserNey-Tech-Report.pdf

    Reads ``self.N`` (number of word ids), ``self.vocab_list`` (word string
    per integer id), and ``self.ug_E_cnt`` / ``self.ug_DP`` (both looked up
    with the word id rendered as a string), plus the module-level
    ``UG_SKIP_SYM_LIST`` and ``NO_E_CNT``.

    Steps:
      1. For each word id, take ``ug_E_cnt - ug_DP`` as the numerator of its
         discounted probability. Skip symbols, ids missing from
         ``ug_E_cnt``, negative counts and ``NO_E_CNT`` entries get 0.0.
      2. Divide every numerator by the sum of the usable counts.
      3. Spread the left-over mass ``1 - sum(discounted)`` uniformly over
         all non-skip words.

    Returns:
        dict mapping word string to probability. ``'<s>'`` is always mapped
        to 0.0; any other symbol in ``UG_SKIP_SYM_LIST`` is left out.

    If no word has a usable count, every discounted probability is 0.0 and
    the whole mass is redistributed uniformly. If there is no non-skip word
    at all, nothing is redistributed. Neither case raises
    ``ZeroDivisionError``.
    """
    ug_probs = {}
    disc_ug_probs = {}  # word id (as str) -> discounted prob
    ug_E_cnt_sum = 0.0
    redistribute_cnt = 0
    for int_key in range(self.N):
        key = str(int_key)
        if self.vocab_list[int_key] in UG_SKIP_SYM_LIST:
            # Skip symbols get no probability and no share of the
            # redistributed mass, so they are not counted below.
            disc_ug_probs[key] = 0.0
            continue

        # Every non-skip word receives a share of the left-over mass.
        redistribute_cnt += 1
        if key not in self.ug_E_cnt:
            disc_ug_probs[key] = 0.0
        elif self.ug_E_cnt[key] < 0.0:  # double check
            print('Fatal Error: ug_E_cnt < 0 for ' + self.vocab_list[int_key])
            disc_ug_probs[key] = 0.0
        elif self.ug_E_cnt[key] == NO_E_CNT:
            # NOTE(review): if NO_E_CNT is negative this branch is
            # unreachable because the `< 0.0` check above fires first.
            disc_ug_probs[key] = 0.0
        else:
            ug_E_cnt_sum += self.ug_E_cnt[key]
            # Only the numerator for now; normalised below.
            disc_ug_probs[key] = self.ug_E_cnt[key] - self.ug_DP[key]

    # Conversion of the numerators to discounted probs.
    psum = 0.0
    if ug_E_cnt_sum != 0.0:
        for k in disc_ug_probs:
            disc_ug_probs[k] = disc_ug_probs[k] / ug_E_cnt_sum
            psum += disc_ug_probs[k]
    else:
        # No usable counts: there is nothing to normalise by, so no word
        # keeps any discounted mass and everything is redistributed.
        for k in disc_ug_probs:
            disc_ug_probs[k] = 0.0

    left_over_prob_mass = 1.0 - psum
    # Redistribute uniformly to all non-skip words (none -> nothing to give).
    if redistribute_cnt:
        redistribute_prob = left_over_prob_mass / redistribute_cnt
    else:
        redistribute_prob = 0.0

    for key, disc_prob in disc_ug_probs.items():
        word = self.vocab_list[int(key)]
        if word == '<s>':
            ug_probs[word] = 0.0
        elif word not in UG_SKIP_SYM_LIST:
            ug_probs[word] = disc_prob + redistribute_prob
    return ug_probs