import sys
import re
import copy
import os
sys.path.insert(1, os.path.abspath('../../'))
from python_tools.aoc_utils import *

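# Numeric keypad (table `lg`); the blank bottom-left cell is the "dead" gap.
# +---+---+---+
# | 7 | 8 | 9 |
# +---+---+---+
# | 4 | 5 | 6 |
# +---+---+---+
# | 1 | 2 | 3 |
# +---+---+---+
#     | 0 | A |
#     +---+---+
#
# Directional keypad (table `sg`); the blank top-left cell is the "dead" gap.
# The code spells the down arrow as "V".
#     +---+---+
#     | ^ | A |
# +---+---+---+
# | < | v | > |
# +---+---+---+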


# Key coordinates as [row, column] on the numeric keypad.  "dead" is the
# empty cell that move_line never routes an arm across.
lg = {
    "7": [0, 0],
    "8": [0, 1],
    "9": [0, 2],
    "4": [1, 0],
    "5": [1, 1],
    "6": [1, 2],
    "1": [2, 0],
    "2": [2, 1],
    "3": [2, 2],
    "dead": [3, 0],
    "0": [3, 1],
    "A": [3, 2],
}
# Key coordinates as [row, column] on every directional keypad.  The down
# arrow is spelled "V" (upper case) throughout this file.
sg = {
    "dead": [0, 0],
    "^": [0, 1],
    "A": [0, 2],
    "<": [1, 0],
    "V": [1, 1],
    ">": [1, 2],
}

# Memo for move_line: G[idx] maps (line, start positions) -> press count.
G = {}


def move_line(line, s, idx):
    """Return how many buttons the human must press to enter ``line``.

    Args:
        line: Key labels to be pressed, in order, on the keypad at level
            ``idx``.
        s: Starting arm position ``[row, col]`` for each remaining level,
            nearest first: ``s[0]`` is the arm on the level-``idx`` keypad,
            ``s[1]`` the arm on the keypad that drives it, and so on.  When
            ``s`` is empty the human types ``line`` directly.
        idx: Level of the keypad ``line`` is typed on.  0 selects the numeric
            keypad ``lg``; any other value selects a directional keypad ``sg``.

    Returns:
        The number of human key presses (``len(line)`` when ``s`` is empty).

    For each key, two routes are tried: "rows first, then columns" and
    "columns first, then rows".  A route whose corner is the dead cell is
    skipped.  Each remaining route is expanded one level further out by
    recursion, and the cheaper one is kept.  Every expanded route is costed
    with the next arm starting at ``s[1]``.  That is exact when ``s[1]`` is
    the "A" key, as in every caller in this file, because each route ends by
    pressing "A".

    Results are memoised in ``G[idx]``, keyed on ``line`` together with the
    full tuple of start positions.  Calls with a different robot chain
    therefore never reuse each other's answers, even if ``G`` is not cleared
    in between.
    """
    if not s:
        return len(line)

    level_memo = G.setdefault(idx, {})
    # The answer depends on every start position, not just on idx, so the
    # whole chain is part of the key.
    key = (line, tuple(tuple(pos) for pos in s))
    if key in level_memo:
        return level_memo[key]

    keypad = lg if idx == 0 else sg
    dead = keypad["dead"]
    # Glyph for moving along axis 0 (rows) / axis 1 (columns):
    # [0] is the decreasing direction, [1] the increasing one.
    glyphs = (("^", "V"), ("<", ">"))
    deeper = s[1:]

    total = 0
    current_pos = s[0]
    for c in line:
        target = keypad[c]
        best = None
        for first, second in ((0, 1), (1, 0)):
            # Between the two legs the arm sits at target[first] on axis
            # `first` and current_pos[second] on axis `second`; never route
            # through the dead cell.
            passes_dead = (target[first] == dead[first]
                           and current_pos[second] == dead[second])
            if passes_dead:
                continue
            todo = ""
            for axis in (first, second):
                diff = target[axis] - current_pos[axis]
                if diff < 0:
                    todo += -diff * glyphs[axis][0]
                elif diff > 0:
                    todo += diff * glyphs[axis][1]
            todo += "A"  # press the target key
            cost = move_line(todo, deeper, idx + 1)
            if best is None or cost < best:
                best = cost
        total += best
        current_pos = target

    level_memo[key] = total
    return total


def _total_complexity(codes, directional_robots):
    """Return the summed complexity of ``codes``.

    A code's complexity is the integer formed by its first run of digits,
    multiplied by the number of human presses needed to type it.  The chain
    is one numeric-keypad robot driven through ``directional_robots``
    robot-operated directional keypads.  Surrounding whitespace is stripped
    and blank lines (e.g. a trailing newline) are skipped.
    """
    # Start from an empty memo so answers computed for a different chain
    # depth are never reused.
    global G
    G = {}
    # Every robot arm starts on its keypad's "A" key.
    start_positions = [lg["A"]] + [sg["A"]] * directional_robots
    total = 0
    for raw in codes:
        code = raw.strip()
        if not code:
            continue
        numeric_part = int(re.findall(r"\d+", code)[0])
        total += numeric_part * move_line(code, start_positions, 0)
    return total


def part_one(input):
    """Print and return the part-one answer (two robot directional keypads).

    ``input`` is an iterable of door codes, one per line.
    """
    total = _total_complexity(input, 2)
    print("Part1: ", total)
    return total


def part_two(input):
    """Print and return the part-two answer (25 robot directional keypads).

    ``input`` is an iterable of door codes, one per line.
    """
    total = _total_complexity(input, 25)
    print("Part2: ", total)
    return total

def main():
    """Read the door codes from ``input.txt`` and print both answers."""
    codes = file2list("input.txt")
    for solve in (part_one, part_two):
        solve(codes)


if __name__ == "__main__":
    main()

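# Reference examples: one human input sequence per sample code for the
# part-one chain (two robot-operated directional keypads), written with a
# lower-case "v" for down.  The 029A sequence is 68 presses long.
# 029A: <vA<AA>>^AvAA<^A>A<v<A>>^AvA^A<vA>^A<v<A>^A>AAvA^A<v<A>A>^AAAvA<^A>A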
# 980A: <v<A>>^AAAvA^A<vA<AA>>^AvAA<^A>A<v<A>A>^AAAvA<^A>A<vA>^A<A>A
# 179A: <v<A>>^A<vA<A>>^AAvAA<^A>A<v<A>>^AAvA^A<vA>^AA<A>A<v<A>A>^AAAvA<^A>A
# 456A: <v<A>>^AA<vA<A>>^AAvAA<^A>A<vA>^A<A>A<vA>^A<A>A<v<A>A>^AAvA<^A>A
# 379A: <v<A>>^AvA^A<vA<AA>>^AAvA<^A>AAvA^A<vA>^AA<A>A<v<A>A>^AAAvA<^A>A