# soft-analytics-02/train/pretrain.py

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# Learning rate for the AdamW optimizer.
ADAM_LR: float = 5e-5


def label(epoch: int, loss: float) -> str:
    """Format a progress-bar description from the current epoch and loss."""
    return f"Epoch={epoch} Loss={loss}"

def pretrain(model, dataloader: DataLoader, device, epochs: int, save_dir: str):
    """Pretrain a causal language model and save the final weights to save_dir."""
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=ADAM_LR)
    print(f"Pretraining for {epochs} epochs")
    for epoch in range(epochs):
        with tqdm(dataloader, desc=f"Epoch {epoch + 1}") as pbar:
            for batch in pbar:
                # Move every tensor in the batch to the target device.
                batch = {k: v.to(device) for k, v in batch.items()}
                optimizer.zero_grad()
                inputs = {'input_ids': batch['input_ids'],
                          'attention_mask': batch['attention_mask']}
                # Causal-LM objective: the input ids double as the labels;
                # the model shifts them internally when computing the loss.
                outputs = model(**inputs, labels=batch['input_ids'])
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                # Show the live loss in the progress bar.
                pbar.set_description(label(epoch + 1, loss.item()))
    model.save_pretrained(save_dir)
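

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one plausible way to drive
# pretrain() with a Hugging Face causal LM. The checkpoint name ("gpt2"), the
# toy corpus, the batch size, and the save directory are illustrative
# assumptions, not values taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    class EncodingDataset(Dataset):
        """Wraps a dict of stacked tensors as per-example dicts,
        matching the batch format pretrain() expects."""

        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return self.encodings["input_ids"].size(0)

        def __getitem__(self, idx):
            return {k: v[idx] for k, v in self.encodings.items()}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    # Any corpus of strings works; two toy snippets keep the sketch runnable.
    corpus = ["def add(a, b):\n    return a + b", "print('hello world')"]
    encodings = tokenizer(corpus, padding=True, truncation=True,
                          max_length=128, return_tensors="pt")
    loader = DataLoader(EncodingDataset(encodings), batch_size=2, shuffle=True)

    pretrain(model, loader, device, epochs=1, save_dir="./pretrained-model")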