import queue
import numpy as np
from enum import Enum

class Action(Enum):
    """Direction the agent can try to move in.

    The values double as indices into ``StateInfo.connected_no``, which
    stores neighbours in the order west, north, east, south.
    """

    West = 0
    North = 1
    East = 2
    South = 3

class ConnerPoint:
    """Read-only record of a numbered point at column ``x`` and row ``y``."""

    def __init__(self, no, x, y):
        # Name-mangled storage; values are exposed only through the
        # read-only properties below.
        self.__no, self.__x, self.__y = no, x, y

    @property
    def no(self):
        """Identifier number of this point."""
        return self.__no

    @property
    def x(self):
        """Column index of this point."""
        return self.__x

    @property
    def y(self):
        """Row index of this point."""
        return self.__y

    def __repr__(self):
        # The row (y) is printed before the column (x).
        fields = (self.no, self.y, self.x)
        return "<Point: [no={}, y={}, x={}]>".format(*fields)

class StateInfo:
    """Neighbour table for one numbered point.

    ``no`` is the point's own number. ``connected_no`` lists the numbers
    of the neighbouring points in the order west, north, east, south,
    exactly as they were passed in.
    """

    def __init__(self, no, west, north, east, south):
        neighbours = (west, north, east, south)
        self.connected_no = list(neighbours)
        self.no = no

class State:
    """A (row, column) position that can be used as a dictionary key.

    Equality and hashing are both derived from ``(row, column)``.
    ``row`` and ``column`` are plain mutable attributes, so an instance
    must not be modified while it is stored in a dict or set.
    """

    def __init__(self, row=-1, column=-1):
        self.row = row
        self.column = column

    def __repr__(self):
        return "<State: [{}, {}]>".format(self.row, self.column)

    def clone(self):
        """Return a new State with the same row and column."""
        return State(self.row, self.column)

    def __hash__(self):
        return hash((self.row, self.column))

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching other.row) keeps
        # comparisons such as ``state == None`` from raising
        # AttributeError; Python then falls back to "not equal".
        if not isinstance(other, State):
            return NotImplemented
        return self.row == other.row and self.column == other.column

class Maze:
    """Maze environment in which the agent hops between corner points.

    Typical use: ``initialize_maze(maze_map)`` to load a map, ``reset()``
    to place the agent on the start point, then ``transit(action)``
    repeatedly. Each transit moves the agent to the neighbouring corner
    point in the requested direction (or leaves it in place when there is
    none) and returns ``(reward, done)``.
    """

    # Cell codes stored in the maze map.
    _START = -1
    _PATH = 0
    _WALL = 1
    _GOAL = 99
    # Rewards handed out by _reward_func.
    _DEFAULT_REWARD = 0
    _UNCHANGED_REWARD = -1
    _GOAL_REWARD = 10

    @property
    def action_num(self):
        """Number of available actions (one per ``Action`` member)."""
        return len(Action)

    @property
    def maze_size(self):
        """Side length of the maze; the map is treated as square."""
        return self._width

    @property
    def posx(self):
        """Column of the agent's current position."""
        return self._current_state.column

    @property
    def posy(self):
        """Row of the agent's current position."""
        return self._current_state.row

    @property
    def observation(self):
        """Return a ``maze_size`` x ``maze_size`` float array of the maze.

        Path cells are 0.0, the agent's cell is 127/255 (a mid-tone) and
        every other cell is 1.0.
        """
        # Provisional encoding.
        # NOTE(review): start and goal cells are not _PATH, so they are
        # encoded as 1.0 just like walls - confirm this is intended.
        size = self.maze_size
        row = self._current_state.row
        column = self._current_state.column
        observation = np.where(
            self._maze_map[:size, :size] == Maze._PATH, 0.0, 1.0)
        if 0 <= row < size and 0 <= column < size:
            observation[row, column] = 127.0 / 255.0
        return observation

    def initialize_maze(self, maze_map):
        """Load ``maze_map`` and derive corner points, network, start, goal.

        ``maze_map`` is copied into a numpy array. The map is assumed to
        be square: both width and height are taken from the length of
        the first row.
        """
        self._maze_map = np.array(maze_map)
        self._width = self._height = len(self._maze_map[0])  # square
        self._search_conner_points()
        self._create_conner_network()
        self._set_start_goal()

    def reset(self):
        """Put the agent back on the start point."""
        self._current_state = State(self._start_point.y, self._start_point.x)

    def transit(self, action):
        """Apply ``action`` and return ``(reward, done)``.

        ``action`` may be an ``Action`` member or the equivalent integer
        index (0=west, 1=north, 2=east, 3=south).
        """
        next_state = self._move(self._current_state, action)
        reward, done = self._reward_func(self._current_state, next_state)
        self._current_state = next_state
        return reward, done

    def print_maze(self):
        """Print the maze, one text row per maze row.

        ``P`` marks the agent, ``S`` the start, ``G`` the goal and
        ``###`` a wall; path cells are blank.
        """
        symbols = {
            Maze._PATH: '   ',
            Maze._START: ' S ',
            Maze._GOAL: ' G ',
            Maze._WALL: '###',
        }
        player = (self._current_state.row, self._current_state.column)
        for i in range(self._height):
            # Cells holding an unknown code contribute nothing to the row.
            print(''.join(
                ' P ' if (i, j) == player
                else symbols.get(self._maze_map[i][j], '')
                for j in range(self._width)))

    def _search_conner_points(self):
        """Collect every corner point of the map into ``self.conner_points``.

        A non-wall cell is a corner point unless it is a straight
        corridor, i.e. unless exactly two of its four neighbours are
        walls and those two walls face each other. Points are numbered
        from 1 in row-major order.

        Raises:
            ValueError: if the maze map is empty.
        """
        if len(self._maze_map) == 0:
            # Raise instead of print + exit(): library code must not
            # terminate the interpreter.
            raise ValueError('not created maze')

        grid = self._maze_map
        wall = Maze._WALL
        open_codes = (Maze._PATH, Maze._START, Maze._GOAL)
        # Every inspected cell needs all four neighbours, so the outermost
        # row/column are skipped; reading [i + 1] / [j + 1] there would
        # raise IndexError.
        last_row = min(self._height, len(grid) - 1)
        last_column = self._width - 1

        conner_points = []
        conner_no = 1
        for i in range(1, last_row):
            for j in range(1, last_column):
                if grid[i][j] not in open_codes:
                    continue

                north_wall = bool(grid[i - 1][j] == wall)
                south_wall = bool(grid[i + 1][j] == wall)
                west_wall = bool(grid[i][j - 1] == wall)
                east_wall = bool(grid[i][j + 1] == wall)
                wall_cnt = north_wall + south_wall + west_wall + east_wall

                # A cell with exactly two walls is normally a corridor;
                # if those walls do not face each other the cell is an
                # L-shaped bend and counts as a corner. Testing the walls
                # (rather than "neighbour == _PATH") also treats start and
                # goal neighbours as open cells.
                straight_corridor = wall_cnt == 2 and (
                    (north_wall and south_wall) or (west_wall and east_wall))
                if not straight_corridor:
                    conner_points.append(ConnerPoint(conner_no, j, i))
                    conner_no += 1

        self.conner_points = conner_points

    def _create_conner_network(self):
        """Build ``self._state_inf_dic``: State -> StateInfo of neighbours.

        For each corner point, the nearest corner point reachable in a
        straight line to the west, north, east and south is recorded
        (``None`` where a wall or the map edge comes first).
        """
        # (x, y) -> corner number, so each scanned cell is an O(1) lookup.
        no_at = {(p.x, p.y): p.no for p in self.conner_points}

        state_inf_dic = {}
        for conner in self.conner_points:
            base_i, base_j = conner.y, conner.x
            west = self._first_conner_no(
                ((base_i, j) for j in range(base_j - 1, 0, -1)), no_at)
            north = self._first_conner_no(
                ((i, base_j) for i in range(base_i - 1, 0, -1)), no_at)
            east = self._first_conner_no(
                ((base_i, j) for j in range(base_j + 1, self._width)), no_at)
            south = self._first_conner_no(
                ((i, base_j) for i in range(base_i + 1, self._height)), no_at)
            state_inf_dic[State(base_i, base_j)] = StateInfo(
                conner.no, west, north, east, south)
        self._state_inf_dic = state_inf_dic

    def _first_conner_no(self, cells, no_at):
        """Return the number of the first corner along ``cells``.

        ``cells`` yields (row, column) pairs walking away from a corner.
        The walk stops at the first wall; ``None`` is returned when no
        corner is met before a wall or the end of ``cells``.
        """
        for i, j in cells:
            if self._maze_map[i][j] == Maze._WALL:
                return None
            no = no_at.get((j, i), 0)
            if no > 0:
                return no
        return None

    def _exist_pos(self, x, y):
        """Return the number of the corner point at (x, y), or 0 if none."""
        for conner in self.conner_points:
            if conner.x == x and conner.y == y:
                return conner.no
        return 0

    def _set_start_goal(self):
        """Choose start and goal points and mark them in the map.

        The first corner point becomes the start; the goal is the point
        returned by the breadth-first search from the start.
        """
        self._start_point = self.conner_points[0]
        self._goal_point = self._search_longest_distance_Point(
            self._start_point, self.conner_points, self._state_inf_dic)[0]
        self._maze_map[self._start_point.y][self._start_point.x] = Maze._START
        self._maze_map[self._goal_point.y][self._goal_point.x] = Maze._GOAL

    def _search_longest_distance_Point(self, start_point, conner_points, state_inf_dic):
        """Breadth-first search for a corner point farthest from the start.

        Returns:
            A ``(point, hop_count)`` tuple where ``hop_count`` is the
            number of corner-to-corner hops from ``start_point``. When
            several points share the maximum, the one discovered last is
            returned. ``(start_point, 0)`` is returned when nothing else
            is reachable.
        """
        point_by_no = {}
        for point in conner_points:
            # Keep the first point for a given number.
            point_by_no.setdefault(point.no, point)

        check_visited_no = [False] * (len(conner_points) + 1)
        check_visited_no[start_point.no] = True

        q = queue.Queue()
        q.put((start_point, 0))
        # BFS discovers points in non-decreasing distance order, so the
        # most recently discovered point is always a farthest one.
        farthest = (start_point, 0)
        while not q.empty():
            cp, hop_count = q.get()
            state_info = state_inf_dic[State(cp.y, cp.x)]
            for connected_no in state_info.connected_no:
                if connected_no is None or check_visited_no[connected_no]:
                    continue
                check_visited_no[connected_no] = True
                # Distance is the parent's distance plus one, not a
                # running count of discovered points.
                farthest = (point_by_no[connected_no], hop_count + 1)
                q.put(farthest)

        return farthest

    def _reward_func(self, state, next_state):
        """Return ``(reward, done)`` for the move ``state`` -> ``next_state``.

        Leaving a path or start cell yields the default reward, or the
        "unchanged" penalty when the position did not change. Acting from
        the goal cell yields the goal reward and ends the episode.
        """
        # NOTE(review): the outcome is decided from the cell being left
        # (``state``), so the goal reward and done flag are issued on the
        # step taken after the agent has reached the goal - confirm this
        # is intended.
        done = False
        reward = Maze._DEFAULT_REWARD
        attribute = self._maze_map[state.row][state.column]
        if attribute == Maze._PATH or attribute == Maze._START:
            if next_state.row == state.row and next_state.column == state.column:
                reward = Maze._UNCHANGED_REWARD
        elif attribute == Maze._GOAL:
            reward = Maze._GOAL_REWARD
            done = True

        return reward, done

    def _move(self, state, action):
        """Return the state reached from ``state`` by ``action``.

        ``action`` may be an Enum member (its ``value`` is used) or an
        integer index into the neighbour list. When there is no neighbour
        in that direction a copy of ``state`` is returned.
        """
        # A list cannot be indexed with an Enum member, so unwrap it.
        index = action.value if isinstance(action, Enum) else action
        next_no = self._state_inf_dic[state].connected_no[index]
        if next_no is None:
            return state.clone()
        return self._no_to_state(next_no)

    def _no_to_state(self, no):
        """Return the State whose corner number is ``no``, or ``None``."""
        # Numbers are expected to be unique, so the first match is the
        # only one.
        return next(
            (k for k, v in self._state_inf_dic.items() if v.no == no), None)