soft-analytics-01/src/modelimpl/load.py
Claudio Maggioni 07232eddcc Final version of the bug-triaging project
Commit history has been discarded to remove large files from the repo.
2024-01-03 15:22:56 +01:00

59 lines
1.7 KiB
Python

import os
from typing import Optional
import numpy as np
import torch
from .classifier import bert_classifier, Classifier
# Directory containing this module; model artifacts live two levels above it, under out/model.
_MODULE_DIR = os.path.dirname(__file__)
OUT_DIR = os.path.join(_MODULE_DIR, '..', '..', 'out', 'model')
def get_model_path(dataset_kind: str, epochs: int, learning_rate: float, suffix_ext: str) -> str:
    """Return the path under ``OUT_DIR`` for a final model artifact.

    The file name encodes the training configuration as
    ``bug_triaging_<dataset_kind>_<epochs>e_<lr>lr_final.<suffix_ext>``, where
    ``<lr>`` is ``str(learning_rate)`` with every '.' replaced by '_'
    (e.g. ``0.001`` -> ``0_001``; ``1e-05`` has no '.' and is kept as-is).
    """
    lr_token = str(learning_rate).replace('.', '_')
    return os.path.join(
        OUT_DIR,
        f'bug_triaging_{dataset_kind}_{epochs!s}e_{lr_token}lr_final.{suffix_ext}',
    )
def _read_label_range(label_range_path: str) -> tuple[int, int]:
    """Read the (start, end) label bounds stored one integer per line in *label_range_path*."""
    with open(label_range_path, "r") as handle:
        start = int(handle.readline())
        end = int(handle.readline())
    return start, end


def load_model(path: str, label_range: Optional[tuple[int, int]], force_cpu: bool,
               force_retrain: bool) -> tuple[bool, Classifier, int]:
    """Build a BERT classifier and load saved weights into it when available.

    Args:
        path: Path to the ``.pt`` weights file. When ``label_range`` is None, the
            sidecar file ``<path without .pt>.label_range.txt`` is read instead.
        label_range: ``(start, end)`` label bounds. ``bert_classifier`` is built
            with ``end - start`` classes. If None, both bounds are read from the
            sidecar file, one integer per line.
        force_cpu: Do not use CUDA even if it is available.
        force_retrain: Ignore an existing weights file at ``path``.

    Returns:
        ``(loaded, model, classes)``. ``loaded`` is True when weights were read
        from ``path``; the model is then switched to eval mode. Otherwise the
        freshly constructed model is returned without loading anything.

    Raises:
        ValueError: If ``path`` does not end with ``.pt``, or if the sidecar file
            does not contain two integers.
        OSError: If the sidecar file is needed but cannot be opened.
    """
    if not path.endswith('.pt'):
        raise ValueError("path should point to a pytorch model file")
    label_range_path = path.removesuffix('.pt') + '.label_range.txt'

    # Seeds only NumPy's global RNG (torch is not seeded here), presumably for reproducibility.
    np.random.seed(0)

    cuda_available = torch.cuda.is_available()
    use_gpu = cuda_available and not force_cpu
    if use_gpu:
        print('Using device #', torch.cuda.current_device())
    elif not cuda_available:
        print('CUDA is not available! Working on CPU...')
    else:
        # Distinguish a deliberate CPU run from a missing GPU, so the log is not misleading.
        print('CUDA is available but force_cpu is set; working on CPU...')

    if label_range is None:
        start_range, end_range = _read_label_range(label_range_path)
    else:
        start_range, end_range = label_range[0], label_range[1]
    classes = end_range - start_range
    model = bert_classifier(classes)

    if not os.path.isfile(path) or force_retrain:
        return False, model, classes

    print('Using already trained model')
    # With map_location=None (torch's default), tensors are restored to the device
    # they were saved from. On CPU runs, remap them explicitly to the CPU.
    map_location = None if use_gpu else torch.device('cpu')
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints;
    # consider weights_only=True if the installed torch version supports it.
    model.load_state_dict(torch.load(path, map_location=map_location))
    model.eval()
    return True, model, classes