import itertools

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext import data  # only needed for the commented-out TabularDataset line below
from gensim.corpora import WikiCorpus
from transformers import GPT2Tokenizer, GPT2Model
from functions import *  # provides evaluate() and adjust_learning_rate() (module not shown)

# Define the model
# class GPT(nn.Module):
#     def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_dim)
#         self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_dim, vocab_size)
#         self.gpt2 = model
#
#     def forward(self, x):
#         # Embed the input
#         x = self.embedding(x)
#         # Pass through the GPT2 model
#         x = self.gpt2(x)
#         # Pass through the LSTM
#         x, _ = self.lstm(x)
#         # Pass through the fully connected layer
#         x = self.fc(x)
#         return x

# Load the GPT2 model
print('load gpt2 model')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# GPT-2 defines no padding token; reuse the end-of-text token so the
# tokenizer can pad a batch to a common length.
tokenizer.pad_token = tokenizer.eos_token

# Load the data
print('load custom data')
# Passing dictionary={} skips gensim's expensive vocabulary build, since we
# only stream raw article texts here.
# wiki_corpus_en = WikiCorpus('data/enwiki-latest-pages-articles.xml.bz2', dictionary={})
wiki_corpus_fr = WikiCorpus('data/frwiki-latest-pages-articles.xml.bz2', dictionary={})
# stackoverflow_corpus = data.TabularDataset('data/stackoverflow.csv', format='csv', fields=['text'])

# Preprocess the data
print('Preprocess the data')
# get_texts() yields token lists; join them back into plain strings for the
# GPT-2 tokenizer. Only the first 10 articles are taken here, matching the 10
# placeholder labels below -- a full dump is far too large to tensorise at once.
# wiki_texts_en = [' '.join(tokens) for tokens in itertools.islice(wiki_corpus_en.get_texts(), 10)]
wiki_texts_fr = [' '.join(tokens) for tokens in itertools.islice(wiki_corpus_fr.get_texts(), 10)]
# stackoverflow_data = [example.text for example in stackoverflow_corpus]

# Convert the data to a format compatible with PyTorch
print('Convert the data to a format compatible with PyTorch')
# Raw strings cannot be passed to torch.tensor() directly; encode them into
# padded input-id / attention-mask tensors instead.
wiki_data_fr = tokenizer(wiki_texts_fr, return_tensors='pt',
                         padding=True, truncation=True, max_length=128)

# GPT2Model returns hidden states, not class logits, so attach a small linear
# classification head for the loss below. Two classes is an assumption that
# matches the binary placeholder labels.
classifier = nn.Linear(model.config.n_embd, 2)

# Define the Adam optimizer (over both the backbone and the head)
print('Define the Adam optimizer')
optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=0.001)

# Define the loss function
print('Define the loss function')
criterion = nn.CrossEntropyLoss()

# Train the model
print('Train the model')
num_epochs = 10
# Placeholder binary labels, one per sampled article; replace with real targets.
labels = torch.tensor([0, 1, 1, 0, 0, 1, 0, 1, 0, 1])
model.train()
for epoch in range(num_epochs):
    print(f'epoch: {epoch}')
    # Forward pass: mean-pool the final hidden states, then classify
    # (the English/StackOverflow variants would be fed the same way)
    hidden = model(**wiki_data_fr).last_hidden_state
    outputs = classifier(hidden.mean(dim=1))
    # Calculate the loss
    loss = criterion(outputs, labels)
    # Backward pass
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Reset the gradients
    optimizer.zero_grad()
    # Evaluate the model. evaluate() comes from the local functions module
    # (not shown); passing the labels alongside the data is an assumption
    # about its signature -- see the sketch at the end of this file.
    accuracy = evaluate(model, wiki_data_fr, labels)
    # Save the model weights and states
    torch.save(model.state_dict(), 'model.pth')
    # Adjust the learning rate (also from functions; assumed to decay the lr)
    adjust_learning_rate(optimizer, epoch)
    # Print the loss and accuracy
    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}')
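
# --------------------------------------------------------------------------
# Hedged sketches of the two helpers imported from functions.py, which is not
# part of this file. These are assumptions inferred from the call sites above,
# not the actual implementations; the real module stays authoritative, so the
# sketches are left commented out to avoid shadowing the imports.
#
# def evaluate(model, data, labels):
#     """Assumed: returns classification accuracy over the encoded batch."""
#     model.eval()
#     with torch.no_grad():
#         hidden = model(**data).last_hidden_state
#         logits = classifier(hidden.mean(dim=1))  # assumes the same global head
#         accuracy = (logits.argmax(dim=-1) == labels).float().mean().item()
#     model.train()
#     return accuracy
#
# def adjust_learning_rate(optimizer, epoch, base_lr=0.001, decay=0.5, step=5):
#     """Assumed: step-decays the learning rate every `step` epochs."""
#     lr = base_lr * (decay ** (epoch // step))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
# --------------------------------------------------------------------------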