#ifndef NINETAILMODEL_H
#define NINETAILMODEL_H
#include <iostream>
#include "Graph.h"
using namespace std;
// The puzzle has 9 cells, each either 'H' or 'T', so there are 2^9 distinct
// configurations; each one is a node of the graph, indexed 0..511.
const int NUMBER_OF_NODES = 1 << 9;
class NineTailModel
{
public:
NineTailModel();
int getIndex(vector<char>& node) const;
vector<char> getNode(int index) const;
vector<int> getShortestPath(int nodeIndex) const;
void printNode(vector<char>& node) const;
protected:
Tree* tree;
vector<Edge> getEdges() const;
int getFlippedNode(vector<char>& node, int position) const;
void flipACell(vector<char>& node, int row, int column) const;
};
// Builds the model: collects every flip edge, builds a graph over all
// NUMBER_OF_NODES configurations, and stores the tree produced by
// `graph.bfs` rooted at the all-tails configuration.
//
// Defined `inline` because this definition lives in a header: without it,
// including the header from more than one translation unit produces
// multiple-definition link errors.
//
// NOTE(review): `tree` is an owning raw pointer and the class declares no
// destructor, so this allocation is never freed. Ownership is not changed
// here because `tree` is protected and derived classes may reassign it.
inline NineTailModel::NineTailModel()
{
  vector<Edge> edges = getEdges();
  Graph<int> graph(NUMBER_OF_NODES, edges);
  // getIndex maps 'T' to bit 1, so the all-tails goal configuration has all
  // 9 bits set: index NUMBER_OF_NODES - 1 (== 511).
  tree = new Tree(graph.bfs(NUMBER_OF_NODES - 1));
}
// Enumerates the graph's edges. For every configuration u and every cell k
// of u that shows 'H', the configuration v obtained by flipping cell k (and
// its orthogonal neighbours) is computed, and the edge is recorded as
// Edge(v, u), i.e. from the flipped configuration to u, not Edge(u, v).
// NOTE(review): presumably reversed so a search rooted at the goal yields
// paths toward it; confirm against Graph.h before changing the direction.
//
// Defined `inline` because this definition lives in a header (ODR).
inline vector<Edge> NineTailModel::getEdges() const
{
  vector<Edge> edges;
  // u ranges over all 2^9 patterns, so each of the 9 cells is 'H' in exactly
  // half of them: NUMBER_OF_NODES * 9 / 2 edges in total.
  edges.reserve(NUMBER_OF_NODES * 9 / 2);

  for (int u = 0; u < NUMBER_OF_NODES; u++)
  {
    // Decode once per node; getFlippedNode mutates its argument, so each
    // flip works on its own copy.
    const vector<char> node = getNode(u);
    for (int k = 0; k < 9; k++)
    {
      if (node[k] == 'H')
      {
        vector<char> flipped = node;
        int v = getFlippedNode(flipped, k);
        edges.push_back(Edge(v, u));
      }
    }
  }
  return edges;
}
// Flips the cell at `position` (0..8, row-major in the 3x3 grid) together
// with its up/down/left/right neighbours, and returns the index of the
// resulting configuration. Neighbours outside the grid are skipped by
// flipACell.
//
// Side effect: `node` is modified in place to the flipped configuration;
// callers that need the original must pass a copy.
//
// Defined `inline` because this definition lives in a header (ODR).
inline int NineTailModel::getFlippedNode(vector<char>& node, int position)
const
{
  int row = position / 3;
  int column = position % 3;
  flipACell(node, row, column);
  flipACell(node, row - 1, column);
  flipACell(node, row + 1, column);
  flipACell(node, row, column - 1);
  flipACell(node, row, column + 1);
  return getIndex(node);
}
// Toggles the cell at (row, column) of the 3x3 grid: 'H' becomes 'T', and
// any other value becomes 'H'. Coordinates outside 0..2 are ignored, which
// lets getFlippedNode pass out-of-grid neighbours of border cells.
//
// Defined `inline` because this definition lives in a header (ODR).
inline void NineTailModel::flipACell(vector<char>& node,
int row, int column) const
{
  if (row < 0 || row > 2 || column < 0 || column > 2)
    return;

  char& cell = node[row * 3 + column];
  cell = (cell == 'H') ? 'T' : 'H';
}
// Encodes the first 9 cells of `node` as a 9-bit number, cell 0 being the
// most significant bit: 'T' contributes 1, any other value contributes 0.
// The result is in 0..511. Assumes `node` holds at least 9 elements (not
// checked).
//
// Defined `inline` because this definition lives in a header (ODR).
inline int NineTailModel::getIndex(vector<char>& node) const
{
  int result = 0;
  for (int i = 0; i < 9; i++)
    result = result * 2 + (node[i] == 'T' ? 1 : 0);
  return result;
}
// Inverse of getIndex for 0 <= index < 512: returns the 9-cell
// configuration whose bits (cell 0 = most significant) are those of
// `index`, with 0 -> 'H' and 1 -> 'T'.
//
// Defined `inline` because this definition lives in a header (ODR).
inline vector<char> NineTailModel::getNode(int index) const
{
  vector<char> result(9);
  // Walk from the last cell backwards, consuming the least significant bit
  // each step.
  for (int i = 8; i >= 0; i--)
  {
    result[i] = (index % 2 == 0) ? 'H' : 'T';
    index /= 2;
  }
  return result;
}
// Returns the path that the stored search tree reports for `nodeIndex`.
// NOTE(review): presumably the node indices from `nodeIndex` up to the
// tree's root (the goal configuration); confirm against Tree::getPath.
//
// Defined `inline` because this definition lives in a header (ODR).
inline vector<int> NineTailModel::getShortestPath(int nodeIndex) const
{
  return tree->getPath(nodeIndex);
}
// Writes the first 9 cells of `node` to standard output as three rows of
// three characters, followed by an empty line.
//
// Rows end with '\n' instead of std::endl so the stream is flushed once per
// grid rather than once per row; the bytes written are unchanged.
// Defined `inline` because this definition lives in a header (ODR).
inline void NineTailModel::printNode(vector<char>& node) const
{
  for (int i = 0; i < 9; i++)
  {
    cout << node[i];
    if (i % 3 == 2)
      cout << '\n';  // end of a row
  }
  cout << endl;  // blank separator line, and flush the finished grid
}
#endif