"""Provides classes for handling Shakespeare datasets.This module contains two main classes:1. Shakespeare: For handling tokenized Shakespeare text using BPE tokenization.2. ShakespeareChar: For handling character-level Shakespeare text.Both classes provide methods for downloading, tokenizing, encoding, and decodingShakespeare's text.Typical usage example: shakespeare = Shakespeare(1024) char_shakespeare = ShakespeareChar()"""importpicklefromcollectionsimportabcfrompathlibimportPathimportnumpyasnpimportrequestsfromtricycle.tokeniserimportBPETokeniser
[docs]classShakespeare(abc.Sequence):"""A class for handling tokenized Shakespeare text using BPE tokenization. This class downloads the Shakespeare dataset, tokenizes it using BPE, and provides methods for encoding and decoding text. Attributes: url: A string containing the URL for the Shakespeare dataset. vocab_size: An integer representing the size of the vocabulary. token_path: A Path object for the tokenized data file. raw_data_path: A Path object for the raw data file. tokens: A numpy array containing the tokenized data. tokeniser: A BPETokeniser object for tokenization. """url:str=("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"# noqa: E501)vocab_size:inttoken_path:Pathraw_data_path:Pathtokens:np.ndarraydef__init__(self,vocab_size:int,token_path:Path|None=None,raw_data_path:Path=Path("datasets/shakespeare/raw_data.txt"),tokeniser_path:Path=Path("datasets/shakespeare/tokeniser.pkl"),):"""Initializes the Shakespeare object. Args: vocab_size: An integer representing the size of the vocabulary. token_path: A Path object for the tokenized data file. If None, a default path is used. raw_data_path: A Path object for the raw data file. tokeniser_path: A Path object for the tokeniser pickle file. """iftoken_pathisNone:token_path=Path(f"datasets/shakespeare/tokens_{vocab_size}.pkl")self.vocab_size=vocab_sizeself.raw_data_path=raw_data_pathself.token_path=token_pathself.tokeniser_path=tokeniser_pathifself.tokeniser_path.exists():withopen(self.tokeniser_path,"rb")asf:self.tokeniser=pickle.load(f)else:self.tokeniser=Noneifnotself.token_path.exists():self.tokeniser=self.generate()self.tokeniser_path.parent.mkdir(parents=True,exist_ok=True)withopen(self.tokeniser_path,"wb")asf:pickle.dump(self.tokeniser,f)self.tokens=self.tokeniser.tokensself.token_path.parent.mkdir(parents=True,exist_ok=True)withopen(self.token_path,"wb")asf:pickle.dump(self.tokens,f)else:withopen(self.token_path,"rb")asf:self.tokens=pickle.load(f)
[docs]defdownload(self):"""Downloads the Shakespeare dataset. The downloaded data is saved to the path specified by raw_data_path. """raw_data=requests.get(self.url).textself.raw_data_path.parent.mkdir(parents=True,exist_ok=True)withopen(self.raw_data_path,"w")asf:f.write(raw_data)
[docs]defgenerate(self)->BPETokeniser:"""Downloads and tokenizes the Shakespeare dataset. Returns: A BPETokeniser object trained on the Shakespeare dataset. """self.download()raw_data=np.array(list(self.raw_data_path.read_bytes()),dtype=np.int32)ifself.tokeniserisNone:self.tokeniser=BPETokeniser(self.vocab_size)returnself.tokeniser.train_ints(raw_data,loading_bar=True)
def__getitem__(self,idx:int)->int|list[int]:"""Returns the token(s) at the specified index. Args: idx: An integer index or slice. Returns: The token(s) at the specified index. """returnself.tokens[idx]def__len__(self)->int:"""Returns the number of tokens in the dataset. Returns: An integer representing the number of tokens. """returnlen(self.tokens)
[docs]defencode(self,*args):"""Encodes the input using the BPE tokenizer. Args: *args: Arguments to pass to the tokenizer's encode method. Returns: The encoded input. """returnself.tokeniser.encode(*args)
[docs]defdecode(self,*args):"""Decodes the input using the BPE tokenizer. Args: *args: Arguments to pass to the tokenizer's decode method. Returns: The decoded input. """returnself.tokeniser.decode(*args)
[docs]classShakespeareChar(abc.Sequence):"""A class for handling character-level Shakespeare text. This class downloads the Shakespeare dataset and provides methods for encoding and decoding text at the character level. Attributes: url: A string containing the URL for the Shakespeare dataset. vocab_size: An integer representing the size of the vocabulary. raw_data_path: A Path object for the raw data file. chars: A list of integers representing the characters in the dataset. """url:str=("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"# noqa: E501)vocab_size:intraw_data_path:Pathchars:list[int]def__init__(self,raw_data_path:Path=Path("datasets/shakespeare/raw_data.txt"),):"""Initializes the ShakespeareChar object. Args: raw_data_path: A Path object for the raw data file. """self.raw_data_path=raw_data_pathself.chars=self.generate()self.vocab_size=len(set(self.chars))
[docs]defencode(self,chars:list[int]|str):"""Encodes the input characters into character IDs. Args: chars: A list of integers or a string to encode. Returns: A list of integer character IDs. """ifisinstance(chars,str):chars=[ord(i)foriinchars]return[self.char_ids[c]forcinchars]
[docs]defdecode(self,char_ids:list[int]):"""Decodes the input character IDs into characters. Args: char_ids: A list of integer character IDs to decode. Returns: A list of decoded characters. """inv_char_ids={i:cforc,iinself.char_ids.items()}return[inv_char_ids[i]foriinchar_ids]
[docs]defdownload(self):"""Downloads the Shakespeare dataset. The downloaded data is saved to the path specified by raw_data_path. """raw_data=requests.get(self.url).textself.raw_data_path.parent.mkdir(parents=True,exist_ok=True)withopen(self.raw_data_path,"w")asf:f.write(raw_data)
[docs]defgenerate(self)->list[int]:"""Downloads and processes the Shakespeare dataset. Returns: A list of integers representing the characters in the dataset. """ifnotself.raw_data_path.exists():self.download()raw_data=list(self.raw_data_path.read_bytes())self.char_ids={c:ifori,cinenumerate(set(raw_data))}returnself.encode(raw_data)
def__getitem__(self,idx:int)->int|list[int]:"""Returns the character(s) at the specified index. Args: idx: An integer index or slice. Returns: The character(s) at the specified index. """returnself.chars[idx]def__len__(self)->int:"""Returns the number of characters in the dataset. Returns: An integer representing the number of characters. """returnlen(self.chars)