59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import os
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from .classifier import bert_classifier, Classifier
|
|
|
|
# Output directory for model artifacts. It is resolved relative to the
# directory containing this module: two levels up, then into out/model.
# The path is not normalised, so it still contains the '..' segments.
OUT_DIR = os.path.join(
    os.path.dirname(__file__),
    '..', '..',
    'out', 'model',
)
|
|
|
|
|
|
def get_model_path(dataset_kind: str, epochs: int, learning_rate: float, suffix_ext: str) -> str:
    """Return the path inside ``OUT_DIR`` for an artifact of one training configuration.

    The file name encodes the dataset kind, the epoch count and the learning
    rate, e.g. ``('all', 3, 0.01, 'pt')`` -> ``bug_triaging_all_3e_0_01lr_final.pt``.

    Args:
        dataset_kind: Label of the dataset variant, inserted verbatim.
        epochs: Number of training epochs.
        learning_rate: Learning rate. Every '.' in its ``str()`` form is
            replaced by '_' so the extension dot stays the only dot in the name.
        suffix_ext: File extension without the leading dot; a '.' is inserted
            before it.

    Returns:
        ``os.path.join(OUT_DIR, <file name>)``.
    """
    lr_token = str(learning_rate).replace('.', '_')
    file_name = f'bug_triaging_{dataset_kind}_{str(epochs)}e_{lr_token}lr_final.{suffix_ext}'
    return os.path.join(OUT_DIR, file_name)
|
|
|
|
|
|
def load_model(path: str, label_range: Optional[tuple[int, int]], force_cpu: bool,
               force_retrain: bool) -> tuple[bool, Classifier, int]:
    """Build a BERT classifier and, when possible, load saved weights into it.

    Args:
        path: Path to a PyTorch state-dict file; must end in ``.pt``.
        label_range: ``(start, end)`` of the label range. If ``None``, the two
            values are read, one per line, from ``<path minus .pt>.label_range.txt``.
        force_cpu: If true, report CPU usage and load weights with
            ``map_location=cpu`` even when CUDA is available.
        force_retrain: If true, the saved weights are ignored even if ``path`` exists.

    Returns:
        ``(loaded, model, classes)``, where ``classes = end - start``. ``loaded`` is
        ``True`` only when the weights were loaded from ``path``; in that case
        the model has also been switched to eval mode. Otherwise the freshly
        constructed (untrained) model is returned with ``False``.

    Raises:
        ValueError: If ``path`` does not end in ``.pt``, or if a line of the
            label-range file is not an integer.
        OSError: From ``open`` when ``label_range`` is ``None`` and the
            label-range file cannot be opened (e.g. it does not exist).

    Side effects:
        Reseeds NumPy's global RNG with 0 and prints the device choice. The
        model is never moved to a device here (no ``.to()``/``.cuda()`` call);
        presumably the caller does that.
    """
    if not path.endswith('.pt'):
        raise ValueError("path should point to a pytorch model file")

    # Sidecar file next to the weights: "model.pt" -> "model.label_range.txt".
    range_file = path[:-3] + '.label_range.txt'

    # NOTE: mutates the process-wide NumPy RNG state, not a local generator.
    np.random.seed(0)

    gpu_enabled = torch.cuda.is_available() and not force_cpu
    if not gpu_enabled:
        print('CUDA is not available! Working on CPU...')
    else:
        print('Using device #', torch.cuda.current_device())

    if label_range is not None:
        low, high = label_range[0], label_range[1]
    else:
        with open(range_file, "r") as handle:
            low = int(handle.readline())
            high = int(handle.readline())

    n_classes = high - low
    classifier = bert_classifier(n_classes)

    if not os.path.isfile(path) or force_retrain:
        return False, classifier, n_classes

    print('Using already trained model')
    # None keeps the device recorded in the checkpoint; on CPU-only runs the
    # tensors are remapped so that a GPU-saved checkpoint can still be read.
    location = None if gpu_enabled else torch.device('cpu')
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True if the torch version allows).
    classifier.load_state_dict(torch.load(path, map_location=location))
    classifier.eval()
    return True, classifier, n_classes
|