import sys
import os
sys.path.insert(1, os.path.abspath('../../'))
from python_tools.aoc_utils import *

import collections as C

# Module-level state shared by the flood-fill helpers and the part solvers.
LIMITS = [0] * 2  # [row_count, col_count] of the grid; written by main()
VISITED = set()  # (row, col) tuples of cells already assigned to a region
CURRENT_PLOT = tuple()  # key under which the flood-fill helpers append cells to G
G = C.defaultdict(list)  # CURRENT_PLOT key -> list of [row, col] cells appended during fills

def add_p_and_ap1(input, cp, a, p):
    """Flood-fill the region containing ``cp`` and accumulate its area and perimeter.

    Every orthogonally connected cell whose character equals
    ``input[cp[0]][cp[1]]`` is visited exactly once, added to ``VISITED`` as a
    ``(row, col)`` tuple and appended to ``G[CURRENT_PLOT]`` as ``[row, col]``.

    Args:
        input: Grid of single-character cells, indexed ``input[row][col]``.
        cp: ``[row, col]`` of the starting cell. Callers pass a cell that is
            not yet in ``VISITED`` (``part_one`` checks this before calling).
        a: Area accumulated so far; one is added per cell of the region.
        p: Perimeter accumulated so far; each cell adds four minus the number
            of its on-grid orthogonal neighbours holding the same character.

    Returns:
        The updated ``(a, p)`` tuple.
    """
    ft = input[cp[0]][cp[1]]

    # An explicit stack replaces the original recursion: one frame per cell
    # would exceed Python's default recursion limit on large regions.
    VISITED.add(tuple(cp))
    stack = [list(cp)]
    while stack:
        cell = stack.pop()
        r, c = cell
        G[CURRENT_PLOT].append(cell)
        a += 1
        p += 4
        for dr, dc in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            nr, nc = r + dr, c + dc
            # Bounds come from the grid itself, so ragged rows are also safe.
            if not (0 <= nr < len(input) and 0 <= nc < len(input[nr])):
                continue
            if input[nr][nc] != ft:
                continue
            # An edge shared with a same-character neighbour is not perimeter.
            p -= 1
            # Mark on push so a cell is never queued twice.
            if (nr, nc) not in VISITED:
                VISITED.add((nr, nc))
                stack.append([nr, nc])
    return a, p

def add_p_and_ap2(input, cp, a, p):
    """Flood-fill the region containing ``cp`` and accumulate its area and side count.

    Every orthogonally connected cell whose character equals
    ``input[cp[0]][cp[1]]`` is visited exactly once, added to ``VISITED`` as a
    ``(row, col)`` tuple and appended to ``G[CURRENT_PLOT]`` as ``[row, col]``.

    A region has as many straight sides as it has corners, so each cell adds
    the corners it sits on. For every pair of adjacent orthogonal directions,
    it is an outer corner when neither neighbour is in the region. It is an
    inner corner when both neighbours are in the region but the diagonal
    cell between them is not.

    NOTE(review): assumes Part 2 prices a region as area * number of sides;
    confirm against the puzzle statement.

    Args:
        input: Grid of single-character cells, indexed ``input[row][col]``.
        cp: ``[row, col]`` of the starting cell. Callers pass a cell that is
            not yet in ``VISITED`` (``part_two`` checks this before calling).
        a: Area accumulated so far; one is added per cell of the region.
        p: Side count accumulated so far; each cell adds its corner count.

    Returns:
        The updated ``(a, p)`` tuple.
    """
    ft = input[cp[0]][cp[1]]

    def same(r, c):
        # True when (r, c) is on the grid and holds the region's character.
        return 0 <= r < len(input) and 0 <= c < len(input[r]) and input[r][c] == ft

    # Clockwise order, so consecutive entries (wrapping around) are the two
    # orthogonal sides of one corner.
    dirs = ((-1, 0), (0, 1), (1, 0), (0, -1))
    corner_pairs = tuple(zip(dirs, dirs[1:] + dirs[:1]))

    # Explicit stack instead of recursion to avoid hitting the recursion limit.
    VISITED.add(tuple(cp))
    stack = [list(cp)]
    while stack:
        cell = stack.pop()
        r, c = cell
        G[CURRENT_PLOT].append(cell)
        a += 1
        for (ar, ac), (br, bc) in corner_pairs:
            side_a = same(r + ar, c + ac)
            side_b = same(r + br, c + bc)
            if not side_a and not side_b:
                p += 1  # outer (convex) corner
            elif side_a and side_b and not same(r + ar + br, c + ac + bc):
                p += 1  # inner (concave) corner
        for dr, dc in dirs:
            nr, nc = r + dr, c + dc
            # Mark on push so a cell is never queued twice.
            if same(nr, nc) and (nr, nc) not in VISITED:
                VISITED.add((nr, nc))
                stack.append([nr, nc])
    return a, p

def part_one(input):
    """Print and return the sum of area * perimeter over every region of ``input``.

    Resets ``VISITED`` and ``G``, then starts a flood fill with
    ``add_p_and_ap1`` from every cell not already claimed by an earlier region.
    """
    # Without this declaration the assignment below would bind a local, and
    # the flood fill would keep grouping every cell under the initial key.
    global CURRENT_PLOT
    VISITED.clear()
    G.clear()
    p1 = 0
    for row, line in enumerate(input):
        for col in range(len(line)):
            if (row, col) in VISITED:
                continue
            CURRENT_PLOT = (row, col)
            a, p = add_p_and_ap1(input, [row, col], 0, 0)
            p1 += a * p
    print("Part 1: ", p1)
    return p1


def part_two(input):
    """Print and return the Part 2 total for ``input``.

    The total is the sum of ``a * p`` over every region, where ``(a, p)`` is
    what ``add_p_and_ap2`` returns for that region. ``VISITED`` and ``G`` are
    reset first. A flood fill then starts from every cell not already claimed
    by an earlier region.
    """
    # Without this declaration the assignment below would bind a local, and
    # the flood fill would keep grouping every cell under the initial key.
    global CURRENT_PLOT
    VISITED.clear()
    G.clear()
    p2 = 0
    for row, line in enumerate(input):
        for col in range(len(line)):
            if (row, col) in VISITED:
                continue
            CURRENT_PLOT = (row, col)
            a, p = add_p_and_ap2(input, [row, col], 0, 0)
            p2 += a * p
    print("Part 2: ", p2)
    return p2

def main():
    """Load the character grid from ``input.txt`` and solve both parts.

    Reads ``input.txt`` from the current working directory, stores the grid
    dimensions in ``LIMITS`` and then runs ``part_one`` and ``part_two``.

    Raises:
        ValueError: if ``input.txt`` has no non-blank lines.
    """
    # ``with`` guarantees the file handle is closed.
    with open("input.txt", 'r') as f:
        # Skip blank lines (e.g. a trailing empty line). An empty row would
        # pass the row-bounds check and then fail when indexed.
        contents = [list(line.strip()) for line in f if line.strip()]
    if not contents:
        raise ValueError("input.txt contains no grid rows")
    LIMITS[0] = len(contents)
    LIMITS[1] = len(contents[0])
    part_one(contents)
    part_two(contents)


if __name__ == "__main__":
    main()