import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

ADAM_LR: float = 5e-5


def label(epoch: int, loss: float) -> str:
    """Format a progress-bar description from the current epoch and loss."""
    return f"Epoch={epoch} Loss={loss:.4f}"


def pretrain(model, dataloader: DataLoader, device, epochs: int, save_dir: str):
    """Run causal-LM pretraining and save the final model to save_dir."""
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=ADAM_LR)
    print(f"Pretraining for {epochs} epochs")
    for epoch in range(epochs):
        with tqdm(dataloader, desc=f"Epoch {epoch + 1}") as pbar:
            for batch in pbar:
                # Move every tensor in the batch to the target device.
                batch = {k: v.to(device) for k, v in batch.items()}
                optimizer.zero_grad()
                inputs = {
                    "input_ids": batch["input_ids"],
                    "attention_mask": batch["attention_mask"],
                }
                # For causal-LM pretraining the inputs serve as their own labels;
                # the model shifts them internally to compute next-token loss.
                outputs = model(**inputs, labels=batch["input_ids"])
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                pbar.set_description(label(epoch + 1, loss.item()))
    model.save_pretrained(save_dir)
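

# Usage sketch, not part of the module above. It assumes a Hugging Face causal
# LM with a `.loss` output and `.save_pretrained()` (matching the calls in
# `pretrain`); the model name "gpt2", the toy corpus, and the batch size are
# illustrative assumptions, not taken from this module.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Toy corpus; replace with a real pre-tokenized dataset yielding dicts
    # with 'input_ids' and 'attention_mask' tensors, as `pretrain` expects.
    texts = ["hello world"] * 8
    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    dataset = [
        {
            "input_ids": encodings["input_ids"][i],
            "attention_mask": encodings["attention_mask"][i],
        }
        for i in range(len(texts))
    ]
    dataloader = DataLoader(dataset, batch_size=4)

    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    pretrain(model, dataloader, device, epochs=1, save_dir="./pretrained")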