import sys
import os
sys.path.insert(1, os.path.abspath('../../'))
from python_tools.aoc_utils import *


# Grid dimensions (number of rows, length of the first row); part_one and
# part_two overwrite both before scanning a grid.
row_limit, col_limit = 0, 0

def check_if_spells_xmas(input, base_row, base_col, dir_row, dir_col):
    """Return True if 'XMAS' is spelled starting at (base_row, base_col).

    Letters are read one step at a time along (dir_row, dir_col); part_one
    passes every combination of -1, 0 and 1. Bounds come from ``input``
    itself, so the result does not depend on the module-level
    row_limit/col_limit having been set first. Any letter that would fall
    outside the grid makes the match fail.

    Args:
        input: the word-search grid, indexable as input[row][col].
        base_row: row of the candidate 'X'.
        base_col: column of the candidate 'X'.
        dir_row: row step applied per letter.
        dir_col: column step applied per letter.

    Returns:
        True when all four letters of 'XMAS' match in order, else False.
    """
    row_count = len(input)
    for step, letter in enumerate('XMAS'):
        row = base_row + step * dir_row
        col = base_col + step * dir_col
        # Explicit bounds check: Python would otherwise silently wrap a
        # negative index around to the opposite edge of the grid.
        if not (0 <= row < row_count and 0 <= col < len(input[row])):
            return False
        if input[row][col] != letter:
            return False
    return True

def get_string(input, row, col, dir):
    """Return the three characters on a diagonal centred at (row, col).

    Column offsets -1, 0, +1 are paired with row offsets -dir, 0, +dir, so
    dir=1 reads top-left -> bottom-right and dir=-1 reads bottom-left ->
    top-right. No bounds checking is done here.
    """
    return "".join(input[row + dir * step][col + step] for step in (-1, 0, 1))

def check_if_spells_mas(input, base_row, base_col):
    """Return True if (base_row, base_col) is the centre of an X-shaped MAS.

    Both diagonals through the centre, as read by get_string, must spell
    'MAS' or 'SAM'. Bounds come from ``input`` itself, so the result does
    not depend on the module-level row_limit/col_limit having been set first.

    Args:
        input: the word-search grid, indexable as input[row][col].
        base_row: row of the candidate centre (part_two passes 'A' cells).
        base_col: column of the candidate centre.

    Returns:
        True when both diagonals match, False otherwise (including when the
        centre sits on the grid edge).
    """
    # Each diagonal needs a neighbour on every side of the centre. Reject
    # edge positions before get_string indexes past the grid or wraps a -1
    # index around to the opposite edge.
    if base_row < 1 or base_row + 1 >= len(input) or base_col < 1:
        return False
    if base_col + 1 >= min(len(input[base_row - 1]), len(input[base_row + 1])):
        return False
    words = ('MAS', 'SAM')
    main_diagonal = get_string(input, base_row, base_col, 1)
    anti_diagonal = get_string(input, base_row, base_col, -1)
    return main_diagonal in words and anti_diagonal in words


def part_one(input):
    """Count every occurrence of XMAS in the word search and print the total.

    Every 'X' cell is tried as a starting point in all nine (dr, dc)
    combinations of -1/0/1; check_if_spells_xmas decides whether the word is
    spelled in that direction.
    """
    global row_limit, col_limit
    # Refresh the module-level grid bounds for any helper that relies on them.
    row_limit = len(input)
    col_limit = len(input[0])
    steps = (-1, 0, 1)
    total = 0
    for r, line in enumerate(input):
        for c, cell in enumerate(line):
            if cell != 'X':
                continue
            total += sum(
                1
                for dr in steps
                for dc in steps
                if check_if_spells_xmas(input, r, c, dr, dc)
            )
    print("Part One: ", total)

def part_two(input):
    """Count X-shaped MAS patterns in the word search and print the total.

    Every 'A' cell is a candidate centre; check_if_spells_mas decides whether
    both diagonals through it read 'MAS' or 'SAM'.
    """
    global row_limit
    global col_limit
    # Refresh the module-level grid bounds for any helper that relies on them.
    row_limit = len(input)
    col_limit = len(input[0])
    pt_2 = 0
    for row, line in enumerate(input):
        for col, cell in enumerate(line):
            if cell == 'A' and check_if_spells_mas(input, row, col):
                pt_2 += 1
    print("Part Two: ", pt_2)

def main():
    """Load the grid from input.txt via file2list and solve both parts."""
    grid = file2list("input.txt")
    for solve in (part_one, part_two):
        solve(grid)


# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()