#import patch_pytorch # for Raspberry Pi testing, from wip branch of github/xloem/mempickle
import math

thousands_names = ' thousand million billion'.split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')

# can we convert between words and numbers?
def number_to_word(num):
    # spell an integer out in English, e.g. 1234 -> 'one thousand two hundred thirty four'
    num = int(num)
    if num == 0:
        return 'zero'
    prefix = ''
    suffix = ''
    if num < 0:
        prefix += 'negative '
        num = -num
    places = int(math.log10(num)) + 1
    for digit in range(0, places, 3):
        value = num % 1000
        num //= 1000
        if value == 0:
            continue
        hundred = value // 100
        ten = (value % 100) // 10
        one = value % 10
        part = ''
        if hundred > 0:
            part += numeral_names[hundred] + ' hundred'
        if ten == 1:
            if len(part):
                part += ' '
            part += teens_names[one]
        else:
            if ten > 0:
                if len(part):
                    part += ' '
                part += tens_names[ten]
            if one > 0:
                if len(part):
                    part += ' '
                part += numeral_names[one]
        if digit > 0 and len(part):
            part += ' ' + thousands_names[digit // 3]
        if len(suffix):
            part += ' '
        suffix = part + suffix
    return prefix + suffix


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    # Byte-level Perceiver: PerceiverTextPreprocessor embeds the input bytes,
    # PerceiverBasicDecoder queries the latents once per output position, and
    # PerceiverEmbeddingDecoder projects back to byte logits through the shared
    # input embedding matrix (the same wiring PerceiverForMaskedLM uses).
    def __init__(self, config):
        super().__init__(config)
        self.input_preprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels = config.d_latents,
            output_index_dims = config.max_position_embeddings,
            num_channels = config.d_model,
            qk_channels = config.qk_channels,
            v_channels = config.d_model,
            num_heads = config.num_decoder_heads,
            use_query_residual = False,
            final_project = False,
            trainable_position_encoding_kwargs = dict(
                num_channels = self.input_preprocessor.num_channels,
                index_dims = config.max_position_embeddings
            ),
        )
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder = self.decoder,
            input_preprocessor = self.input_preprocessor,
        )
        self.output_postprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverEmbeddingDecoder(config)
        self.post_init()

    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):  #, return_dict=None, input_ids=None
        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,  #return_dict,
        )
        logits = self.output_postprocessor(
            #outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
            outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )
        loss = None
        if labels is not None:
            # CrossEntropyLoss skips label positions equal to -100 (its default ignore_index)
            loss = torch.nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
        output = (logits,) + outputs[1:] # outputs[2:]
        if loss is None:
            return output
        else:
            return ((loss,) + output)


# shrink the default config for a small byte-level experiment
config = transformers.PerceiverConfig()
config.num_decoder_heads = config.num_cross_attention_heads
config.num_self_attends_per_block = 6
config.max_position_embeddings = 64
config.d_model = 128
config.d_latents = 256
config.vocab_size = 256

print('Constructing model ...', flush=True)
model = Model(config)

## maybe: per-process vmem for low-end systems; https://github.com/xloem/mempickle
#import pytorch_tensormap
#mmap_params = pytorch_tensormap.PyTorchMap()
#mmap_params.write(model.state_dict())
#model.load_state_dict(mmap_params.read(writeable = True))
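
# Optional sanity check, a sketch only: one untrained forward pass on a dummy
# byte batch (all spaces, byte 32) should produce logits of shape
# (batch, max_position_embeddings, vocab_size) = (1, 64, 256).
with torch.no_grad():
    _dummy = torch.full((1, config.max_position_embeddings), 32, dtype=torch.long)
    _dummy_mask = torch.ones_like(_dummy, dtype=torch.float32)
    print('smoke-test logits shape:', tuple(model(inputs=_dummy, attention_mask=_dummy_mask)[0].shape), flush=True)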
def numbers_to_numword_tensors(numbers, batch_size):
    # returns (numbers, words) as int8 byte tensors of shape
    # (num_batches, batch_size, maxlen); each row ends with '.' and is padded with -100
    maxlen = config.max_position_embeddings
    words_tensor = torch.empty((len(numbers), maxlen), dtype=torch.int8)
    numbers_tensor = torch.empty((len(numbers), maxlen), dtype=torch.int8)
    for idx, number in enumerate(numbers):
        word = number_to_word(number) + '.'
        # b'\x9c' is -100, the label mask
        words_tensor[idx] = torch.frombuffer(word.encode('iso-8859-1').ljust(maxlen, b'\x9c'), dtype=torch.int8)
        number = str(int(number)) + '.'
        numbers_tensor[idx] = torch.frombuffer(number.encode('iso-8859-1').ljust(maxlen, b'\x9c'), dtype=torch.int8)
    return (
        numbers_tensor.view(len(numbers) // batch_size, batch_size, maxlen),
        words_tensor.view(len(numbers) // batch_size, batch_size, maxlen)
    )


total = 100000
batch_size = 16 #total // 256
#tt_split = batch_size #len(data) // 16
total = total - (total % batch_size)

print('Generating data ...', flush=True)
all_numbers, all_words = numbers_to_numword_tensors(torch.randperm(total), batch_size)
#all_numbers = all_numbers.to(torch.long)
#all_words = all_words.to(torch.long)

# hold out the first batch for testing
train_numbers = all_numbers[1:]
test_numbers = all_numbers[:1]
train_words = all_words[1:]
test_words = all_words[:1]

# the first training example repeated batch_size times, used as a constant label batch below
constant_word_labels = torch.stack([train_words[0][0] for x in range(batch_size)]).to(torch.long)
constant_number_labels = torch.stack([train_numbers[0][0] for x in range(batch_size)]).to(torch.long)

# so on one end of the model, we take or output the number;
# on the other end, we output or take the word
print('Starting training ...', flush=True)
cuda = True
if cuda:
    model.cuda()
    constant_number_labels = constant_number_labels.cuda()
    constant_word_labels = constant_word_labels.cuda()
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.0002)

for idx, (number_batch, word_batch) in enumerate(zip(train_numbers, train_words)):
    optim.zero_grad()

    number_data = number_batch.to(torch.long)
    if cuda:
        number_data = number_data.cuda()
    number_mask = (number_data != -100).to(torch.float32)
    number_data[number_data == -100] = 32  # replace -100 padding with spaces so the ids are valid embedding indices

    word_data = word_batch.to(torch.long)
    if cuda:
        word_data = word_data.cuda()
    word_mask = (word_data != -100).to(torch.float32)
    word_data[word_data == -100] = 32

    #labels = number_data.clone() # train_numbers[0,0,0]
    #labels[number_mask == 0] = -100
    labels = constant_number_labels  # currently the constant label batch, not the real per-example labels above
    inputs = word_data
    attention_mask = word_mask

    loss, logits, output = model(inputs=inputs, attention_mask=attention_mask, labels=labels)
    loss.backward()

    # show the first example of the batch: the word input and the bytes the model currently predicts
    observation = inputs[0].detach().to(torch.uint8).cpu().numpy().tobytes()
    observation = observation[:observation.find(b'.')]
    guess = logits[0].detach().argmax(dim=1).to(torch.uint8).cpu().numpy().tobytes()
    guess = guess[:guess.find(b'.')]
    print(f'{idx} {loss.item():.4f} {observation} -> {guess} ', flush=True)  #, end='\r')

    optim.step()
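
# Evaluation sketch (one possible use of the held-out split): test_numbers and
# test_words are split off above but never used; decode the held-out word batch
# with the trained model, applying the same masking convention as the training loop.
model.eval()
with torch.no_grad():
    eval_inputs = test_words[0].to(torch.long)
    eval_mask = (eval_inputs != -100).to(torch.float32)
    eval_inputs[eval_inputs == -100] = 32
    if cuda:
        eval_inputs = eval_inputs.cuda()
        eval_mask = eval_mask.cuda()
    eval_logits = model(inputs=eval_inputs, attention_mask=eval_mask)[0]
    for obs_row, guess_row in zip(eval_inputs, eval_logits.argmax(dim=-1)):
        obs = obs_row.to(torch.uint8).cpu().numpy().tobytes()
        guess = guess_row.to(torch.uint8).cpu().numpy().tobytes()
        print('eval', obs[:obs.find(b'.')], '->', guess[:guess.find(b'.')], flush=True)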