当存在多个出现频率相等的字符对时,Byte-pair编码会按照字符对在语料库中出现的顺序,优先将出现位置靠前的字符对合并。代码示例如下:
from collections import Counter
def get_max_freq_pair(counter):
"""
获取出现频率最高的字符对
"""
most_commons = counter.most_common()
max_freq = most_commons[0][1]
max_freq_pairs = [pair for pair, freq in most_commons if freq == max_freq]
return max_freq_pairs[0]
def merge_pair(pair, text):
"""
将字符对合并为一个新字符,并返回新文本
"""
new_text = ""
i = 0
while i < len(text):
if i < len(text) - 1 and text[i:i+2] == pair:
new_text += pair
i += 2
else:
new_text += text[i]
i += 1
return new_text
text = "aaabbb"
vocab = Counter(text)
while True:
max_freq_pair = get_max_freq_pair(vocab)
if vocab[max_freq_pair] == 1:
# 所有字符对出现频率均已为1,停止合并
break
new_char = "".join(max_freq_pair) # 合并为一个新字符
text = merge_pair(max_freq_pair, text) # 将字符对合并为新字符
vocab[new_char] = vocab[max_freq_pair] # 更新新字符出现频率
del vocab[max_freq_pair[0]], vocab[max_freq_pair[1]] # 删除旧字符对
print(text) # 输出合并后的文本