diff --git a/run.py b/run.py index f963d58..b879e9e 100644 --- a/run.py +++ b/run.py @@ -14,7 +14,7 @@ def run(show_plots=False): if show_plots: instance.plot_data() - for method in ["random", "nearest_neighbors"]: + for method in ["random", "nearest_neighbors", "best_nn"]: solver = Solver_TSP(method) solver(instance, return_value=False) diff --git a/src/TSP_solver.py b/src/TSP_solver.py index 7362ffa..2cb68d0 100644 --- a/src/TSP_solver.py +++ b/src/TSP_solver.py @@ -41,10 +41,12 @@ class Solver_TSP: n = int(instance_.nPoints) node = np.argmin([starting_node]) tour = [node] - for _ in range(n - 2): - for node in np.argsort(dist_matrix[node]): - if node not in tour: - tour.append(node) + for _ in range(n - 1): + for new_node in np.argsort(dist_matrix[node]): + if new_node not in tour: + tour.append(new_node) + node = new_node + break tour.append(starting_node) self.solution = np.array(tour) self.solved = True