java - What's wrong with my implementation of the nearest neighbour algorithm (for the TSP)?

I was tasked with implementing the nearest neighbour algorithm for the travelling salesman problem. It was said that the method should try starting from every city and return the best tour found. According to the auto-marking program, my implementation works correctly for the most basic case, but only works partially for all more advanced cases.

I don't understand where I went wrong, and would appreciate a review of my code for correctness, including what the correct approach would be.

My Java code is as follows:

/*
 * Returns the shortest tour found by exercising the NN algorithm 
 * from each possible starting city in table.
 * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
 */
 public static int[] tspnn(double[][] table) {
     
     // number of vertices 
     int numberOfVertices = table.length;
     // the Hamiltonian cycle built starting from vertex i
     int[] currentHamiltonianCycle = new int[numberOfVertices];
     // the lowest total cost Hamiltonian cycle
     double lowestTotalCost = Double.POSITIVE_INFINITY;
     //  the shortest Hamiltonian cycle
     int[] shortestHamiltonianCycle = new int[numberOfVertices];
     
     // consider each vertex i as a starting point
     for (int i = 0; i < numberOfVertices; i++) {
         /* 
          * Consider all vertices that are reachable from the starting point i,
          * thereby creating a new current Hamiltonian cycle.
          */
         for (int j = 0; j < numberOfVertices; j++) {
             /* 
              * The modulo of the sum of i and j allows us to account for the fact 
              * that Java indexes arrays from 0.
              */
             currentHamiltonianCycle[j] = (i + j) % numberOfVertices;   
         }
         for (int j = 1; j < numberOfVertices - 1; j++) {
             int nextVertex = j;
             for (int p = j + 1; p < numberOfVertices; p++) {
                 if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                           nextVertex = p;
                 }
             }
             
             int a = currentHamiltonianCycle[nextVertex];
             currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
             currentHamiltonianCycle[j] = a;
         }
         
         /*
          * Find the total cost of the current Hamiltonian cycle.
          */
         double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
         for (int z = 0; z < numberOfVertices - 1; z++) {
             currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
         }
         
         if (currentTotalCost < lowestTotalCost) {
             lowestTotalCost = currentTotalCost;
             shortestHamiltonianCycle = currentHamiltonianCycle;
         }
     }
     return shortestHamiltonianCycle;
 }

Edit

I've gone through this code with pen and paper for a simple example, and I can't find any problems with the algorithm implementation. Based on this, it seems to me that it should work in the general case.


Edit 2

I have tested my implementation with the following mock example:

double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1}, 
                {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}}; 

It seems to produce the expected output for the nearest neighbour algorithm, which is 3 -> 1 -> 2 -> 0.

I am now wondering whether the auto-marking program is incorrect, or whether it's just that my implementation does not work in the general case.


1 Reply


As I have stated in my comments, I see one basic problem with the algorithm itself:

  • It will NOT properly permute the towns, but always work in sequence (A-B-C-D-A-B-C-D, start anywhere and take 4)

To prove that problem, I wrote the following code for testing and setting up simple and advanced examples.

  • Please first configure it via the static public final constants, before you change the code itself.
  • Focusing on the simple example: if the algorithm worked fine, the result would always be either A-B-C-D or D-C-B-A.
  • But as you can observe with the output, the algorithm will not select the (globally) best tour, because it does its permutations of tested towns wrong.
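Separately from the ordering concern above, one detail of the posted `tspnn` can be checked in isolation: the line `shortestHamiltonianCycle = currentHamiltonianCycle;` assigns an array reference, and `currentHamiltonianCycle` is allocated once and overwritten for every starting vertex. The sketch below is a stripped-down illustration of that pattern, not the original method; the class and method names (`AliasingDemo`, `keepByReference`, `keepByCopy`) are hypothetical.

```java
import java.util.Arrays;

// Hypothetical demo class; none of these names appear in the original code.
class AliasingDemo {

    // Mimics the pattern in tspnn: one working array is reused in every round,
    // and the "best" result is remembered by plain assignment.
    static int[] keepByReference(int rounds) {
        int[] current = new int[2];
        int[] best = new int[2];
        for (int i = 0; i < rounds; i++) {
            current[0] = i;
            current[1] = i + 1;
            if (i == 0) {
                best = current; // stores the reference, not a snapshot
            }
        }
        return best; // reflects whatever was written in the last round
    }

    // Same loop, but the remembered result is an independent copy.
    static int[] keepByCopy(int rounds) {
        int[] current = new int[2];
        int[] best = new int[2];
        for (int i = 0; i < rounds; i++) {
            current[0] = i;
            current[1] = i + 1;
            if (i == 0) {
                best = current.clone(); // snapshot of the array contents
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(keepByReference(3))); // prints [2, 3]
        System.out.println(Arrays.toString(keepByCopy(3)));      // prints [0, 1]
    }
}
```

If the same effect applies in `tspnn`, the returned array would always hold the tour built from the last starting vertex, whichever start was actually cheapest. That would be consistent with the mock example in the question, whose result starts at vertex 3, the last one.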

I've added in my own Object-Oriented implementation to showcase:

  • problems with selection, which is really hard to get right in ONE method all at once
  • how the OO style has its advantages
  • that proper testing/developing is quite easy to set up and perform (I'm not even using Unit tests here, that would be the next step to verify/validate algorithms)
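For comparison with the matrix-based signature used in the question, here is a minimal array-based sketch of the same "try every start, keep the cheapest tour" idea. It is an illustration under assumptions, not the marking program's reference solution: the names `NearestNeighbourSketch`, `tourFrom`, `cost` and `bestTour` are hypothetical, ties are broken by lowest vertex index, and a fresh array is created for every starting vertex.

```java
import java.util.Arrays;

// Hypothetical sketch; class and method names are illustrative only.
class NearestNeighbourSketch {

    // Builds one nearest-neighbour tour that starts at 'start'.
    static int[] tourFrom(double[][] table, int start) {
        int n = table.length;
        boolean[] visited = new boolean[n];
        int[] tour = new int[n];
        tour[0] = start;
        visited[start] = true;
        for (int step = 1; step < n; step++) {
            int prev = tour[step - 1];
            int next = -1;
            // Pick the cheapest unvisited vertex; ties go to the lowest index.
            for (int c = 0; c < n; c++) {
                if (!visited[c] && (next == -1 || table[prev][c] < table[prev][next])) {
                    next = c;
                }
            }
            tour[step] = next;
            visited[next] = true;
        }
        return tour;
    }

    // Cost of the closed cycle, including the edge back to the first vertex.
    static double cost(double[][] table, int[] tour) {
        double total = 0;
        for (int i = 0; i < tour.length; i++) {
            total += table[tour[i]][tour[(i + 1) % tour.length]];
        }
        return total;
    }

    // Runs the heuristic from every start and returns the cheapest tour found.
    static int[] bestTour(double[][] table) {
        int[] best = new int[0];
        double bestCost = Double.POSITIVE_INFINITY;
        for (int start = 0; start < table.length; start++) {
            int[] tour = tourFrom(table, start); // fresh array per start
            double c = cost(table, tour);
            if (c < bestCost) {
                bestCost = c;
                best = tour;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Mock table taken from the question's "Edit 2".
        double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1},
                            {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}};
        System.out.println(Arrays.toString(bestTour(table))); // prints [1, 3, 2, 0]
    }
}
```

On that mock table the tours from starts 0, 2 and 3 each cost about 6.8, while the tour from start 1 costs about 5.5, so this sketch returns 1 -> 3 -> 2 -> 0 rather than 3 -> 1 -> 2 -> 0.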

Code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;

public class TSP_NearestNeighbour {



    static public final int NUMBER_OF_TEST_RUNS = 4;

    static public final boolean GENERATE_SIMPLE_TOWNS = true;

    static public final int NUMBER_OF_COMPLEX_TOWNS         = 10;
    static public final int DISTANCE_RANGE_OF_COMPLEX_TOWNS = 100;



    static private class Town {
        public final String Name;
        public final int    X;
        public final int    Y;
        public Town(final String pName, final int pX, final int pY) {
            Name = pName;
            X = pX;
            Y = pY;
        }
        public double getDistanceTo(final Town pOther) {
            final int dx = pOther.X - X;
            final int dy = pOther.Y - Y;
            return Math.sqrt(Math.abs(dx * dx + dy * dy));
        }
        @Override public int hashCode() { // not really needed here
            final int prime = 31;
            int result = 1;
            result = prime * result + X;
            result = prime * result + Y;
            return result;
        }
        @Override public boolean equals(final Object obj) {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            final Town other = (Town) obj;
            if (X != other.X) return false;
            if (Y != other.Y) return false;
            return true;
        }
        @Override public String toString() {
            return Name + " (" + X + "/" + Y + ")";
        }
    }

    static private double[][] generateDistanceTable(final ArrayList<Town> pTowns) {
        final double[][] ret = new double[pTowns.size()][pTowns.size()];
        for (int outerIndex = 0; outerIndex < pTowns.size(); outerIndex++) {
            final Town outerTown = pTowns.get(outerIndex);

            for (int innerIndex = 0; innerIndex < pTowns.size(); innerIndex++) {
                final Town innerTown = pTowns.get(innerIndex);

                final double distance = outerTown.getDistanceTo(innerTown);
                ret[outerIndex][innerIndex] = distance;
            }
        }
        return ret;
    }



    static private ArrayList<Town> generateTowns_simple() {
        final Town a = new Town("A", 0, 0);
        final Town b = new Town("B", 1, 0);
        final Town c = new Town("C", 2, 0);
        final Town d = new Town("D", 3, 0);
        return new ArrayList<>(Arrays.asList(a, b, c, d));
    }
    static private ArrayList<Town> generateTowns_complex() {
        final ArrayList<Town> allTowns = new ArrayList<>();
        for (int i = 0; i < NUMBER_OF_COMPLEX_TOWNS; i++) {
            final int randomX = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final int randomY = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final Town t = new Town("Town-" + (i + 1), randomX, randomY);
            if (allTowns.contains(t)) { // do not allow different towns at same location!
                System.out.println("Towns colliding at " + t);
                --i;
            } else {
                allTowns.add(t);
            }
        }
        return allTowns;
    }
    static private ArrayList<Town> generateTowns() {
        if (GENERATE_SIMPLE_TOWNS) return generateTowns_simple();
        else return generateTowns_complex();
    }



    static private void printTowns(final ArrayList<Town> pTowns, final double[][] pDistances) {
        System.out.println("Towns:");
        for (final Town town : pTowns) {
            System.out.println("	" + town);
        }

        System.out.println("Distance Matrix:");
        for (int y = 0; y < pDistances.length; y++) {
            System.out.print("	");
            for (int x = 0; x < pDistances.length; x++) {
                System.out.print(pDistances[y][x] + " (" + pTowns.get(y).Name + "-" + pTowns.get(x).Name + ")" + "	");
            }
            System.out.println();
        }
    }



    private static void testAlgorithm() {
        final ArrayList<Town> towns = generateTowns();

        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
            final double[][] distances = generateDistanceTable(towns);
            printTowns(towns, distances);

            {
                final int[] path = tspnn(distances);
                System.out.println("tspnn Path:");
                for (int pathIndex = 0; pathIndex < path.length; pathIndex++) {
                    final Town t = towns.get(path[pathIndex]); // path holds town indices, so look up by the stored value
                    System.out.println("	" + t);
                }
            }
            {
                final ArrayList<Town> path = tspnn_simpleNN(towns);
                System.out.println("tspnn_simpleNN Path:");
                for (final Town t : path) {
                    System.out.println("	" + t);
                }
                System.out.println("\n");
            }

            // prepare for the next run; shuffling at the end of the loop keeps the first run on the original order
            Collections.shuffle(towns);
        }

    }

    public static void main(final String[] args) {
        testAlgorithm();
    }



    /*
     * Returns the shortest tour found by exercising the NN algorithm
     * from each possible starting city in table.
     * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
     */
    public static int[] tspnn(final double[][] table) {

        // number of vertices
        final int numberOfVertices = table.length;
        // the Hamiltonian cycle built starting from vertex i
        final int[] currentHamiltonianCycle = new int[numberOfVertices];
        // the lowest total cost Hamiltonian cycle
        double lowestTotalCost = Double.POSITIVE_INFINITY;
        //  the shortest Hamiltonian cycle
        int[] shortestHamiltonianCycle = new int[numberOfVertices];

        // consider each vertex i as a starting point
        for (int i = 0; i < numberOfVertices; i++) {
            /*
             * Consider all vertices that are reachable from the starting point i,
             * thereby creating a new current Hamiltonian cycle.
             */
            for (int j = 0; j < numberOfVertices; j++) {
                /*
                 * The modulo of the sum of i and j allows us to account for the fact
                 * that Java indexes arrays from 0.
                 */
                currentHamiltonianCycle[j] = (i + j) % numberOfVertices;
            }
            for (int j = 1; j < numberOfVertices - 1; j++) {
                int nextVertex = j;
                for (int p = j + 1; p < numberOfVertices; p++) {
                    if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                        nextVertex = p;
                    }
                }

                final int a = currentHamiltonianCycle[nextVertex];
                currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
                currentHamiltonianCycle[j] = a;
            }

            /*
             * Find the total cost of the current Hamiltonian cycle.
             */
            double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
            for (int z = 0; z < numberOfVertices - 1; z++) {
                currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
            }

            if (currentTotalCost < lowestTotalCost) {
                lowestTotalCost = currentTotalCost;
                shortestHamiltonianCycle = currentHamiltonianCycle;
            }
        }
        return shortestHamiltonianCycle;
    }



    /**
     * Here come my basic implementations.
     * They can be heavily (heavily!) improved, but are verbose and direct to show the logic behind them
     */



    /**
     * <p>Example of how to implement the NN solution the OO way.</p>
     * we could also implement
     * <ul>
     * <li>a recursive function</li>
     * <li>or one with running counters</li>
     * <li>or one with a real map/route objects, where further optimizations can take place</li>
     * </ul>
     */
    public static ArrayList<Town> tspnn_simpleNN(final ArrayList<Town> pTowns) {
        ArrayList<Town> bestRoute = null;
        double bestCosts = Double.MAX_VALUE;

        for (final Town startingTown : pTowns) {
            //setup
            final ArrayList<Town> visitedTowns = new ArrayList<>(); // ArrayList because we need a stable index
            final HashSet<Town> unvisitedTowns = new HashSet<>(pTowns); // all towns are available at start; we use HashSet because we need fast search; indexing plays no role here

            // step 1
            Town currentTown = startingTown;
            visitedTowns.add(currentTown);
            unvisitedTowns.remove(currentTown);

            // steps 2-n
            while (unvisitedTowns.size() > 0) {
                // find nearest town
                final Town nearestTown = findNearestTown(currentTown, unvisitedTowns);
                if (nearestTown == null) throw new IllegalStateException("Something in the code is wrong...");

                currentTown = nearestTown;
                visitedTowns.add(currentTown);
                unvisitedTowns.remove(currentTown);
            }

            // selection
            final double cost = getCostsOfRoute(visitedTowns);
            if (cost < bestCosts) {
                bestCosts = cost;
                bestRoute = visitedTowns;
            }
        }
        return bestRoute;
    }



    static private Town findNearestTown(final Town pCurrentTown, final HashSet<Town> pSelectableTowns) {
        double minDist = Double.MAX_VALUE;
        Town minTown = null;

        for (final Town checkTown : pSelectableTowns) {
            final double dist = pCurrentTown.getDistanceTo(checkTown);
            if (dist < minDist) {
     
