1  import java.util.*;
  2  
  3  public class NineTailModel {
  4    public final static int NUMBER_OF_NODES = 512;
  5    protected UnweightedGraph<Integer>.SearchTree tree; 
  6  
  7    /** Construct a model */
  8    public NineTailModel() {
  9      // Create edges
 10      List<Edge> edges = getEdges();
 11      
 12      // Create a graph
 13      UnweightedGraph<Integer> graph = new UnweightedGraph<>(
 14        edges, NUMBER_OF_NODES); 
 15  
 16      // Obtain a BSF tree rooted at the target node
 17      tree = graph.bfs(511);
 18    }
 19  
 20    /** Create all edges for the graph */
 21    private List<Edge> getEdges() {
 22      List<Edge> edges =
 23        new ArrayList<>(); // Store edges
 24  
 25      for (int u = 0; u < NUMBER_OF_NODES; u++) {
 26        for (int k = 0; k < 9; k++) {
 27          char[] node = getNode(u); // Get the node for vertex u
 28          if (node[k] == 'H') {
 29            int v = getFlippedNode(node, k);
 30            // Add edge (v, u) for a legal move from node u to node v
 31            edges.add(new Edge(v, u));
 32          }
 33        }
 34      }
 35  
 36      return edges;
 37    }
 38  
 39    public static int getFlippedNode(char[] node, int position) {
 40      int row = position / 3;
 41      int column = position % 3;
 42  
 43      flipACell(node, row, column);
 44      flipACell(node, row - 1, column);
 45      flipACell(node, row + 1, column);
 46      flipACell(node, row, column - 1);
 47      flipACell(node, row, column + 1);
 48  
 49      return getIndex(node);
 50    }
 51  
 52    public static void flipACell(char[] node, int row, int column) {
 53      if (row >= 0 && row <= 2 && column >= 0 && column <= 2) { 
 54        // Within the boundary
 55        if (node[row * 3 + column] == 'H')
 56          node[row * 3 + column] = 'T'; // Flip from H to T
 57        else
 58          node[row * 3 + column] = 'H'; // Flip from T to H
 59      }
 60    }
 61  
 62    public static int getIndex(char[] node) {
 63      int result = 0;
 64  
 65      for (int i = 0; i < 9; i++)
 66        if (node[i] == 'T')
 67          result = result * 2 + 1;
 68        else
 69          result = result * 2 + 0;
 70  
 71      return result;
 72    }
 73  
 74    public static char[] getNode(int index) {
 75      char[] result = new char[9];
 76  
 77      for (int i = 0; i < 9; i++) {
 78        int digit = index % 2;
 79        if (digit == 0)
 80          result[8 - i] = 'H';
 81        else
 82          result[8 - i] = 'T';
 83        index = index / 2;
 84      }
 85  
 86      return result;
 87    }
 88    
 89    public List<Integer> getShortestPath(int nodeIndex) {
 90      return tree.getPath(nodeIndex);
 91    }
 92  
 93    public static void printNode(char[] node) {
 94      for (int i = 0; i < 9; i++)
 95        if (i % 3 != 2)
 96          System.out.print(node[i]);
 97        else
 98          System.out.println(node[i]);
 99  
100      System.out.println();
101    }
102  }