Byte Pair Encoding (BPE): algorithm and code notes

2024-01-29 02:28

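The Byte Pair Encoding (BPE) algorithm


  1. Split the text in the corpus into individual characters
  2. Count how often each pair of adjacent symbols (bigram) occurs
  3. Merge the most frequent pair into a single symbol and add it to the vocabulary
  4. Repeat steps 2 and 3 until the vocabulary reaches the preset size, or until no pair is left to merge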
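
The four steps above can be sketched as a tiny training loop. This is only an illustration under stated assumptions: the function name `train_bpe`, the word-frequency dictionary `toy_words`, and the use of a fixed number of merges as the stopping rule (instead of a target vocabulary size) are all made up for this example and are not part of the code discussed in these notes.

```python
from collections import Counter

def train_bpe(words, num_merges):
    """Toy BPE training. `words` maps each word to its corpus frequency (hypothetical input format)."""
    # step 1: split every word into a tuple of single characters
    corpus = {tuple(w): f for w, f in words.items()}
    merges = []
    for _ in range(num_merges):
        # step 2: count adjacent pairs, weighted by how often the word occurs
        pair_counts = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break  # no pair left to merge
        # step 3: pick the most frequent pair and merge it everywhere
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        new_corpus = {}
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            key = tuple(out)
            new_corpus[key] = new_corpus.get(key, 0) + freq
        corpus = new_corpus  # step 4: repeat on the updated corpus
    return merges, corpus

toy_words = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, corpus = train_bpe(toy_words, num_merges=3)
print(merges)
print(corpus)
```

The order of `merges` is what matters later: at encoding time, pairs that were merged earlier get priority over pairs that were merged later.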



BPE (Byte Pair Encoding) converts an arbitrary UTF-8 string into a sequence of integer indices, which is convenient for the neural network computation that follows.

"""
bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into
sequences of integers, where each integer represents small chunks of commonly
occurring characters. This implementation is based on openai's gpt2
but was mildly modified because the original implementation is a bit confusing.
I also tried to add as many comments as possible, my own understanding of what's
going on.
"""

import os
import json
import regex as re
import requests

import torch

# -----------------------------------------------------------------------------

def bytes_to_unicode():
    """
    Map every byte (8 bits -> 2**8 = 256 values) to a unicode character.
    Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode
    character that represents it visually. Some bytes have their appearance preserved
    because they don't cause any trouble. These are defined in list bs. For example:
    chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!".
    However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these
    bytes, into new characters in a range where chr() returns a single nice character.
    So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8).
    In particular, the space character is 32, which we can see by ord(' '). Instead,
    this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'.
    So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters
    that "look nice", either in their original form, or a funny shifted character
    like 'Ā', or 'Ġ', etc.
    """
    # the 188 integers that render fine in their original form and need no shifting
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict
    # now get the representations of the other 68 integers that do need shifting
    # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop
    n = 0
    for b in range(2**8):
        if b not in bs:
            # if this byte is "ugly" then map it to the next available "nice" character
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    d = dict(zip(bs, cs))
    return d

def get_pairs(word):
    """
    Return all bigrams as a set of tuples, of consecutive elements in the iterable word.
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

class Encoder:

    def __init__(self, encoder, bpe_merges):
        # byte encoder/decoder
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        # bpe token encoder/decoder
        self.encoder = encoder  # token string -> integer index
        self.decoder = {v:k for k,v in self.encoder.items()}  # integer index -> token string
        # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # the splitting pattern used for pre-tokenization
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment
        #
        # ok so what is this regex looking for, exactly?
        # python re reference: the vertical bar | is OR, so re.findall will chunk up the text
        # as the pieces match, from left to right
        # - '\'s' would split up things like Andrej's -> (Andrej, 's)
        # - ' ?\p{L}+': optional space followed by 1+ unicode code points in the category "letter"
        # - ' ?\p{N}+': optional space followed by 1+ unicode code points in the category "number"
        # - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number
        # - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace
        #                so this will consume whitespace characters in a sequence but exclude the last whitespace in
        #                that sequence. that last whitespace has the opportunity to then match the optional ' ?' in
        #                earlier patterns.
        # - '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string
        # So TLDR:
        # - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens
        # - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces
        #
        # In short: the regex pre-splits the string into runs of letters, digits, whitespace
        # and other characters, plus a few English-specific rules for contractions.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.cache = {}

    def bpe(self, token):
        """
        Further split each pre-tokenized token with BPE. The split relies on self.bpe_ranks,
        the merge priorities learned from a large corpus (a lower rank means the pair was
        merged earlier during training, i.e. it was more frequent).
        this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens
        up the tree. token is a string of one individual 'word' (after regex tokenization)
        and after byte encoding, e.g. 'Ġthere'.
        """
        # token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere'

        # memoization, for efficiency: the cache speeds up repeated tokens
        if token in self.cache:
            return self.cache[token]

        word = tuple(token) # individual characters that make up the token, in a tuple
        pairs = get_pairs(word) # get all bigrams

        if not pairs:
            return token

        while True:

            # find the next lowest rank bigram that can be merged
            # (lowest rank = merged earliest in training = most frequent pair)
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
                break # no more bigrams are eligible to be merged
            first, second = bigram

            # we will now replace all occurrences of (first, second) in the list of current
            # words into one merged token first_second, in the output list new_words
            new_word = []
            i = 0
            while i < len(word):  # merge the pair (it may occur more than once)

                # find the next occurrence of first in the sequence of current words
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                # if this occurrence is also followed by second, then merge them into one
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            # all occurrences of (first, second) have been merged to first_second
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)

        # concat all words into a string, and use ' ' as the separator. Note that
        # by now all characters have been byte encoded, guaranteeing that ' ' is
        # not used in the actual data and is a 'special' delimiter character
        word = ' '.join(word)

        # cache the result and return
        self.cache[token] = word
        return word

    def encode(self, text):
        """ string goes in, list of integers comes out """
        bpe_idx = []
        # pre-tokenize the input text into string tokens (words, roughly speaking)
        tokens = re.findall(self.pat, text)  # coarse split with the regex
        # process each token into BPE integers
        for token in tokens:  # inside each token, BPE keeps merging bigrams
            # encode the token as a bytes (b'') object
            token_bytes = token.encode('utf-8')
            # translate all bytes to their unicode string representation and flatten
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            # perform all the applicable bpe merges according to self.bpe_ranks
            token_merged = self.bpe(token_translated).split(' ')
            # translate all bpe tokens to integers
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            # extend our running list of all output integers
            bpe_idx.extend(token_ix)
        return bpe_idx

    def encode_and_show_work(self, text):
        """ debugging function, same as encode but returns all intermediate work """
        bpe_idx = []
        parts = []
        tokens = re.findall(self.pat, text)
        for token in tokens:
            token_bytes = token.encode('utf-8')
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            token_merged = self.bpe(token_translated).split(' ')
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            bpe_idx.extend(token_ix)
            parts.append({
                'token': token,
                'token_bytes': token_bytes,
                'token_translated': token_translated,
                'token_merged': token_merged,
                'token_ix': token_ix,
            })
        out = {
            'bpe_idx': bpe_idx, # the actual output sequence
            'tokens': tokens, # result of pre-tokenization
            'parts': parts, # intermediates for each token part
        }
        return out

    def decode(self, bpe_idx):
        """ list of integers comes in, string comes out """
        # inverse map the integers to get the tokens
        tokens_merged = [self.decoder[token] for token in bpe_idx]
        # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
        tokens_flat = ''.join(tokens_merged)
        tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
        # recover the full utf-8 string
        text = tokens_bytes.decode('utf-8', errors='replace')
        return text

def get_file(local_file, remote_file):
    """ downloads remote_file to local_file if necessary """
    if not os.path.isfile(local_file):
        print(f"downloading {remote_file} to {local_file}")
        response = requests.get(remote_file)
        with open(local_file, "wb") as f:
            f.write(response.content)

def get_encoder():
    """
    Initialized from the cached files of OpenAI's official GPT-2 tokenizer.
    Returns an instance of the GPT BPE Encoder/Decoder
    and handles caching of "database" files.
    """
    home_dir = os.path.expanduser('~')
    cache_dir = os.path.join(home_dir, '.cache', 'mingpt')
    os.makedirs(cache_dir, exist_ok=True)

    # load encoder.json that has the raw mappings from token -> bpe index
    encoder_local_file = os.path.join(cache_dir, 'encoder.json')
    encoder_remote_file = ''
    get_file(encoder_local_file, encoder_remote_file)
    with open(encoder_local_file, 'r') as f:
        encoder = json.load(f)
    assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token

    # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure
    # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab
    vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
    vocab_remote_file = ''
    get_file(vocab_local_file, vocab_remote_file)
    with open(vocab_local_file, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # light postprocessing: strip the version on first line and the last line is a blank
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    assert len(bpe_merges) == 50000 # 50,000 merged tokens

    # construct the Encoder object and return
    enc = Encoder(encoder, bpe_merges)
    return enc

# -----------------------------------------------------------------------------

class BPETokenizer:
    """ PyTorch-aware class that wraps the Encoder above """

    def __init__(self):
        self.encoder = get_encoder()

    def __call__(self, text, return_tensors='pt'):
        # PyTorch only; here because we want to match huggingface/transformers interface
        assert return_tensors == 'pt'
        # single string input for now, in the future potentially a list of strings
        assert isinstance(text, str)
        # encode and create a "batch dimension" of 1
        idx = [self.encoder.encode(text)]
        # wrap into PyTorch tensor
        out = torch.tensor(idx, dtype=torch.long)
        return out

    def decode(self, idx):
        # ensure a simple 1D tensor for now
        assert idx.ndim == 1
        # decode indices to text
        text = self.encoder.decode(idx.tolist())
        return text
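
To see the byte-to-unicode step in isolation, the sketch below repeats `bytes_to_unicode` from the listing (so that the snippet runs on its own) and pushes one short string through the mapping and back. The sample string `" héllo"` is an arbitrary choice for illustration; nothing here touches the real GPT-2 vocabulary files.

```python
def bytes_to_unicode():
    # same logic as in the listing: 188 "nice" bytes map to themselves,
    # the other 68 are shifted to chr(256 + n)
    bs = list(range(ord("!"), ord("~") + 1)) \
        + list(range(ord("¡"), ord("¬") + 1)) \
        + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = " héllo"                       # arbitrary sample: a leading space and a 2-byte character
token_bytes = text.encode("utf-8")    # 7 bytes: 'é' takes two of them
translated = "".join(byte_encoder[b] for b in token_bytes)
restored = bytearray(byte_decoder[c] for c in translated).decode("utf-8")

print(repr(translated))  # the space shows up as 'Ġ', 'é' as two printable characters
print(repr(restored))
```

Because the mapping is one-to-one over all 256 byte values, the round trip is lossless, which is what lets `decode` recover the original text.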


def bpe(self, token):
    # memoization: the cache speeds up bpe
    if token in self.cache:
        return self.cache[token]

    word = tuple(token) # individual characters that make up the token, in a tuple
    pairs = get_pairs(word) # get all bigrams

    if not pairs:
        return token

    while True:

        # find the next lowest rank bigram that can be merged
        # (lowest rank = merged earliest in training = most frequent pair)
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
            break # no more bigrams are eligible to be merged
        first, second = bigram

        # we will now replace all occurrences of (first, second) in the list of current
        # words into one merged token first_second, in the output list new_words
        new_word = []
        i = 0
        while i < len(word):  # merge the pair (it may occur more than once)

            # find the next occurrence of first in the sequence of current words
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                new_word.extend(word[i:])
                break

            # if this occurrence is also followed by second, then merge them into one
            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1

        # all occurrences of (first, second) have been merged to first_second
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # concat all words into a string, and use ' ' as the separator. Note that
    # by now all characters have been byte encoded, guaranteeing that ' ' is
    # not used in the actual data and is a 'special' delimiter character
    word = ' '.join(word)

    # cache the result and return
    self.cache[token] = word
    return word


# memoization: the cache speeds up bpe
if token in self.cache:
    return self.cache[token]
word = tuple(token) # individual characters that make up the token, in a tuple
pairs = get_pairs(word) # get all bigrams
# where get_pairs is defined as:
def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
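
A quick check of what `get_pairs` returns may help here. The snippet copies the function so that it runs standalone; the inputs (`'Ġthere'`, `'aaa'`, and a single character) are just illustrative choices.

```python
def get_pairs(word):
    # copied from the notes: all adjacent pairs of the sequence, as a set
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

pairs_there = get_pairs(tuple("Ġthere"))   # six symbols give five distinct pairs
pairs_repeat = get_pairs(tuple("aaa"))     # repeated pairs collapse, since the result is a set
pairs_single = get_pairs(("a",))           # a one-symbol word has no pairs at all

print(pairs_there)
print(pairs_repeat)
print(pairs_single)
```

The empty result for a one-symbol word is what triggers the early `return token` in `bpe`.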
if not pairs:
    return token
# (everything from here to the final join runs inside the `while True:` loop)
# find the next lowest rank bigram that can be merged
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))  # lowest rank first, i.e. the pair that was most frequent at training time
# where self.bpe_ranks is built in Encoder.__init__ as:
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# and bpe_merges is loaded in get_encoder as:
vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
vocab_remote_file = ''
get_file(vocab_local_file, vocab_remote_file)
with open(vocab_local_file, 'r', encoding="utf-8") as f:
    bpe_data = f.read()
# light postprocessing: strip the version on first line and the last line is a blank
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
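
The way the merge file turns into ranks, and how `min` then picks the next pair, can be tried on a made-up merge table. The string `bpe_data` below is a hypothetical four-line stand-in for `vocab.bpe` (one header line, one merge per line, trailing newline); the real file is far larger and its contents are not reproduced here.

```python
# hypothetical miniature of vocab.bpe: a header line, then one merge per line
bpe_data = "#version: toy\nĠ t\nh e\nĠt he\nr e\n"

# same postprocessing as in the notes: drop the header line and the trailing blank
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]

# position in the file becomes the rank: earlier line = lower rank = higher priority
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

# the pairs of 'Ġthere'; pairs missing from the table get rank +inf
pairs = {("Ġ", "t"), ("t", "h"), ("h", "e"), ("e", "r"), ("r", "e")}
bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))

print(bpe_merges)
print(bpe_ranks)
print(bigram)
```

Note that no frequencies are stored at encoding time. Only the order of the merges survives from training, and that order is enough to replay them.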
if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
    break # no more bigrams are eligible to be merged (leaves the enclosing `while True:` loop)
first, second = bigram
# we will now replace all occurrences of (first, second) in the list of current
# words into one merged token first_second, in the output list new_words
new_word = []
i = 0
while i < len(word):  # merge the pair (it may occur more than once)
    # find the next occurrence of first in the sequence of current words
    try:
        j = word.index(first, i)
        new_word.extend(word[i:j])
        i = j
    except ValueError:
        new_word.extend(word[i:])
        break
    # if this occurrence is also followed by second, then merge them into one
    if word[i] == first and i < len(word)-1 and word[i+1] == second:
        new_word.append(first+second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1
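
The inner scan is easier to follow when pulled out into a function of its own. The helper name `merge_pair` is hypothetical (the notes keep this logic inline inside `bpe`), and the sample words are arbitrary.

```python
def merge_pair(word, first, second):
    """Replace every adjacent (first, second) in the tuple `word` by first+second."""
    new_word = []
    i = 0
    while i < len(word):
        # jump to the next occurrence of `first`, copying everything before it
        try:
            j = word.index(first, i)
        except ValueError:
            new_word.extend(word[i:])  # no more `first`: copy the rest and stop
            break
        new_word.extend(word[i:j])
        i = j
        # merge only if `first` is directly followed by `second`
        if i < len(word) - 1 and word[i + 1] == second:
            new_word.append(first + second)
            i += 2
        else:
            new_word.append(word[i])
            i += 1
    return tuple(new_word)

merged_banana = merge_pair(tuple("banana"), "a", "n")
merged_overlap = merge_pair(("a", "a", "a"), "a", "a")
merged_none = merge_pair(tuple("xyz"), "a", "b")
print(merged_banana, merged_overlap, merged_none)
```

The `("a", "a", "a")` case shows that the scan is greedy from the left: once two symbols are merged, the index skips past both, so overlapping occurrences are not merged twice.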
# all occurrences of (first, second) have been merged to first_second
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
    break
else:
    pairs = get_pairs(word)
# after the loop: join the merged symbols with ' ' as the separator
word = ' '.join(word)
# cache the result and return
self.cache[token] = word
return word
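
Putting the pieces together, here is a compact standalone version of the loop that can be run on one token. It is a sketch under assumptions, not the code from the notes: `toy_ranks` is a made-up four-entry merge table, the function takes the ranks and the cache as arguments instead of reading them from `self`, and the inner scan is a plain left-to-right pass rather than the `word.index` variant.

```python
def get_pairs(word):
    # adjacent pairs of the current symbol sequence
    return set(zip(word, word[1:]))

def bpe(token, bpe_ranks, cache):
    if token in cache:                      # memoization
        return cache[token]
    word = tuple(token)
    pairs = get_pairs(word)
    if not pairs:                           # single-symbol token: nothing to merge
        return token
    while True:
        # lowest rank = highest merge priority; unknown pairs rank as +inf
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    result = ' '.join(word)                 # ' ' is safe: real spaces were mapped to 'Ġ'
    cache[token] = result
    return result

# hypothetical merge table, in priority order
toy_ranks = {("Ġ", "t"): 0, ("h", "e"): 1, ("Ġt", "he"): 2, ("r", "e"): 3}
cache = {}
merged = bpe("Ġthere", toy_ranks, cache)
print(merged)
print(cache)
```

With this toy table the token is merged in four rounds (`Ġ t`, `h e`, `Ġt he`, `r e`) and then stops, because the remaining pair is not in the table. A second call with the same token is answered from `cache` without running the loop.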


