from dis import Instruction
from hashlib import blake2b
import re
from tabnanny import check
import numpy as np
import sys


# Puzzle input. Alternative inputs used during development:
# f = open('input.txt', 'r')
# f = open('test.txt', 'r')
# The context manager guarantees the handle is released even if reading fails;
# `f` stays bound afterwards, so a later f.close() is a harmless no-op.
with open('test2.txt', 'r') as f:
    content = f.read()
lines = content.splitlines()

# Locate the first "S" in the grid as (row, column). The search stops at the
# first hit, so a start marker sitting at (0, 0) is handled like any other;
# (0, 0) is also the fallback when the grid contains no "S" at all.
S_location = (0, 0)
for count, line in enumerate(lines):
    ch = line.find("S")
    if ch != -1:
        S_location = (count, ch)
        break

# Working sets for the walk; `locations` starts with only the start position.
locations = {S_location}
odd_set = set()
even_set = set()

def get_num_steps_to_sat(loc_set, num_steps, tiles):
    """Walk the grid step by step inside one tile until the count stabilises.

    Starting from the positions in ``loc_set``, every step replaces the set by
    all orthogonal neighbours that are inside the grid and not ``"#"``. Moves
    that leave the grid are not followed; instead they are recorded as
    "bleed-over" entry points into the neighbouring tile.

    Args:
        loc_set: Set of ``(row, column)`` tuples indexing the module-level
            ``lines`` grid.
        num_steps: Maximum number of steps to simulate.
        tiles: ``(row, column)`` identifier of the tile being walked; only
            used to derive the identifiers of neighbouring tiles.

    Returns:
        A tuple ``(steps, bleed_over_points, reachable)``:

        * ``steps`` is the 0-based index of the step at which the early exit
          fired, or ``0`` when all ``num_steps`` steps were simulated.
        * ``bleed_over_points`` is a list of four entries ordered
          ``[row - 1, row + 1, column - 1, column + 1]``. Each entry is either
          ``[]`` (never crossed) or ``[location, step, tile]`` where
          ``location`` is the wrapped position (a list) in the neighbouring
          tile, ``step`` is the 1-based step on which the crossing happened
          and ``tile`` is the neighbouring tile identifier. A later crossing
          in the same direction overwrites an earlier one.
        * ``reachable`` is the position count picked for ``num_steps``.
    """
    # One independent list per direction ([[]] * 4 would repeat a single list).
    bleed_over_points = [[] for _ in range(4)]
    # Grid extent along each axis: number of rows for index 0, columns for 1.
    extents = (len(lines), len(lines[0]) if lines else 0)
    second_to_last_count = len(loc_set)
    last_count = len(loc_set)
    count = 0
    while count < num_steps:
        new_set = set()
        for loc in loc_set:
            for index in (0, 1):
                for offset in (-1, 1):
                    moved = loc[index] + offset
                    location = list(loc)
                    location[index] = moved % extents[index]
                    if 0 <= moved < extents[index]:
                        if lines[location[0]][location[1]] != "#":
                            new_set.add(tuple(location))
                        continue
                    # The move left the grid: remember where it re-enters the
                    # adjacent tile. Integer arithmetic keeps tile ids as ints.
                    direction = -1 if moved < 0 else 1
                    new_tile = list(tiles)
                    new_tile[index] += direction
                    slot = index * 2 + (0 if direction < 0 else 1)
                    bleed_over_points[slot] = [location, count + 1, tuple(new_tile)]

        loc_set = new_set
        # Early exit: the new size equals the size from two steps earlier, which
        # is taken as the sign that the counts now just alternate. Only sizes
        # are compared, not the sets themselves.
        if second_to_last_count == len(loc_set):
            # Choose the size whose step parity matches num_steps.
            if (num_steps - count) % 2 == 0:
                return count, bleed_over_points, last_count
            return count, bleed_over_points, second_to_last_count

        count += 1
        second_to_last_count = last_count
        last_count = len(loc_set)
    return 0, bleed_over_points, len(loc_set)
        
def get_positions(loc_set, num_steps, tiles):
    """Return how many grid positions are occupied after walking one tile.

    Starting from ``loc_set``, every step replaces the set by all orthogonal
    neighbours that lie inside the module-level ``lines`` grid and are not
    ``"#"``. Moves that would leave the grid are dropped.

    The step counter starts at 1, so at most ``num_steps - 1`` steps are
    simulated (presumably because the step that entered this tile has already
    been spent -- confirm against the intended caller).

    Args:
        loc_set: Set of ``(row, column)`` tuples; may be empty.
        num_steps: Step budget for this tile.
        tiles: Tile identifier. Accepted for signature compatibility only;
            the result does not depend on it.

    Returns:
        The number of occupied positions: either the size of the final set,
        or, when the set size repeats with period two, the earlier size whose
        parity matches ``num_steps``.
    """
    # Grid extent along each axis: number of rows for index 0, columns for 1.
    extents = (len(lines), len(lines[0]) if lines else 0)
    second_to_last_count = len(loc_set)
    last_count = len(loc_set)
    count = 1
    while count < num_steps:
        new_set = set()
        for loc in loc_set:
            for index in (0, 1):
                for offset in (-1, 1):
                    location = list(loc)
                    location[index] += offset
                    if not 0 <= location[index] < extents[index]:
                        continue  # would step into a neighbouring tile
                    if lines[location[0]][location[1]] != "#":
                        new_set.add(tuple(location))

        loc_set = new_set
        # Early exit once the size matches the size from two steps earlier.
        if second_to_last_count == len(loc_set):
            if (num_steps - count) % 2 == 0:
                return last_count
            return second_to_last_count

        count += 1
        second_to_last_count = last_count
        last_count = len(loc_set)

    return len(loc_set)

# Early initialisation of the seed set; the name is rebound to a fresh set
# further down before any active code reads it.
starting_set = set()

# for i in range(2650):
#     new_set = set()
#     print(len(locations))
#     for loc in locations:
#         for index in [0, 1]:
#             for offset in [-1, 1]:
#                 location = list(loc)
#                 location[index] += offset
#                 set_of_int = even_set if i%2 == 0 else odd_set
#                 # if tuple(location) not in set_of_int and lines[location[0]%len(lines)][location[1]%len(lines[0])] != "#":
#                     set_of_int.add(tuple(location))
#                     new_set.add(tuple(location))
#     locations = new_set.copy()


# 
# print(get_num_steps_to_sat(locations))

# for starting_loc in ((0,0),(0,len(lines)-1),(len(lines[0])-1,0),(len(lines[0])-1,len(lines)-1)):
#     locations = set()
#     locations.add(starting_loc)
#     print(get_num_steps_to_sat(locations))

# for starting_loc in ((0,S_location[1]),(len(lines)-1,S_location[1]),(S_location[0],0),(S_location[0],len(lines[0])-1)):
#     locations = set()
#     locations.add(starting_loc)
#     print(get_num_steps_to_sat(locations))


# Show where the start marker was found.
print(S_location)

# Two separate seed sets, each holding only the start position.
starting_set = {S_location}
starting_points = {S_location}

# Tiles that rec_fun has already expanded; the origin tile is known up front.
tile_set = {(0, 0)}

# Trial run: 16 steps from the start inside the origin tile.
print(get_num_steps_to_sat(starting_points, 16, (0, 0)))

# rec_fun recurses once per newly discovered tile, so lift Python's default
# recursion ceiling before calling it.
sys.setrecursionlimit(10 ** 6)

def rec_fun(set_of_interest, num, tile):
    """Sum the position counts of this tile and every tile reached from it.

    Walks ``tile`` from ``set_of_interest`` with a budget of ``num`` steps via
    ``get_num_steps_to_sat``, then recurses into each neighbouring tile whose
    bleed-over entry is non-empty and whose identifier is not yet in the
    module-level ``tile_set``. Prints progress information along the way and
    adds every newly visited tile to ``tile_set``.

    Args:
        set_of_interest: Set of ``(row, column)`` start positions in ``tile``.
        num: Steps still available when entering ``tile``.
        tile: Identifier of the tile being walked.

    Returns:
        The count reported for ``tile`` plus the totals of all recursive calls.
    """
    _, bleed_overs, count = get_num_steps_to_sat(set_of_interest, num, tile)
    print("bleed over", bleed_overs)
    print("count", count)
    total = count
    for entry in bleed_overs:
        # Empty entries mark directions in which the walk never left the tile.
        if len(entry) == 0:
            continue
        entry_location, steps_used, next_tile = entry[0], entry[1], entry[2]
        if next_tile in tile_set:
            continue
        print(entry)
        tile_set.add(next_tile)
        # Continue in the neighbouring tile with whatever budget is left.
        total += rec_fun({tuple(entry_location)}, num - steps_used, next_tile)

    return total
            

# Main run: total over all tiles reached from the start with a 500-step budget.
print(rec_fun(starting_points, 500, (0,0)))
# for i in starting_set:
#     sat_num, bleed_overs = get_num_steps_to_sat(locations)
#     for point in bleed_overs:


# for starting_loc in ((0,0),(0,len(lines)-1),(len(lines[0])-1,0),(len(lines[0])-1,len(lines)-1)):
#     locations = set()
#     locations.add(starting_loc)
#     print(get_num_steps_to_sat(locations, 500))

# for starting_loc in ((0,S_location[1]),(len(lines)-1,S_location[1]),(S_location[0],0),(S_location[0],len(lines[0])-1)):
#     locations = set()
#     locations.add(starting_loc)
#     print(get_num_steps_to_sat(locations, 500))


# print(S_location)

# Release the input file handle bound to `f` at the top of the script.
f.close()