import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Learning rate for the AdamW optimizer used in pretrain().
ADAM_LR: float = 5e-5
def label(epoch: int, loss: float) -> str:
    """Format a progress-bar label from the current epoch and loss."""
    return f"Epoch={epoch} Loss={loss:.4f}"
def pretrain(model, dataloader: DataLoader, device: torch.device, epochs: int, save_dir: str) -> None:
    """Run a simple causal language-modeling pretraining loop and save the model."""
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=ADAM_LR)

    print(f"Pretraining for {epochs} epochs")

    for epoch in range(epochs):
        with tqdm(dataloader, desc=f"Epoch {epoch + 1}") as pbar:
            for batch in pbar:
                # Move the batch tensors to the target device.
                batch = {k: v.to(device) for k, v in batch.items()}
                optimizer.zero_grad()

                # For causal LM pretraining the inputs also serve as the labels;
                # the model shifts them internally when computing the loss.
                inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
                outputs = model(**inputs, labels=batch["input_ids"])
                loss = outputs.loss

                loss.backward()
                optimizer.step()

                # Show the running loss in the progress bar.
                pbar.set_description(label(epoch + 1, loss.item()))

    # Persist the final weights (Hugging Face-style save_pretrained).
    model.save_pretrained(save_dir)
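

# --- Usage sketch (not part of the original script) ---------------------------------
# Assumptions: `model` is a Hugging Face causal LM (e.g. GPT-2), so that `outputs.loss`
# and `save_pretrained` exist, and the DataLoader yields dicts with "input_ids" and
# "attention_mask" tensors. The model name, texts, batch size, and save_dir below are
# illustrative placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    texts = ["hello world", "another example sentence"]
    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Default collation stacks per-sample dicts of tensors back into batched dicts.
    samples = [
        {"input_ids": ids, "attention_mask": mask}
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    dataloader = DataLoader(samples, batch_size=2)

    pretrain(model, dataloader, device, epochs=1, save_dir="./pretrained")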