import numpy as np

# Fixed integration timestep applied on every call to universe.process().
DELTA_TIME = 0.05
# Gravitational constant used to build universe.matA. The physical constant is
# 6.674e-11 m^3 / (kg * s^2); this value is 1e11 times larger, presumably
# rescaled so gravity is visible at simulation scale -- confirm intended units.
GRAV_CONST = 6.674


class universe:
    """Container and integrator for a 2-D particle simulation.

    Particles are registered with ``add_particle`` and static lines with
    ``add_line``. Each call to ``process`` advances the system by one
    ``DELTA_TIME`` step:

    * mutual gravity plus a constant uniform acceleration;
    * a semi-implicit Euler update;
    * particle-particle and particle-line collision response.

    ``UNIFORM_ACCEL`` is two-dimensional, so positions and velocities are
    expected to be 2-vectors.

    Per-particle state lives in parallel lists indexed by insertion order.
    Pairwise quantities are cached as ``(n, n)`` arrays:

    * ``matA[i, j]`` -- ``GRAV_CONST * m_j`` (zero on the diagonal)
    * ``matB[i, j]`` -- ``|x_j - x_i|**2`` (1 wherever that is exactly 0)
    * ``matC[i, j]`` -- ``(r_i + r_j)**2``, the squared contact distance

    ``force_direction[i, j]`` is the unit vector from particle ``i`` towards
    particle ``j``. It is the zero vector when the two coincide.
    """

    # Constant acceleration added to every particle each step. Presumably
    # "downwards" in screen coordinates where +y points down -- TODO confirm.
    UNIFORM_ACCEL = (0.0, 10.0)
    # Coefficient of restitution for particle-particle collisions. A value of
    # 1.0 is perfectly elastic. It is applied symmetrically to both particles,
    # so momentum is conserved.
    RESTITUTION = 0.9999

    def __init__(self):
        # Pairwise caches, rebuilt by the _recalculate_* helpers.
        self.matA = None
        self.matB = None
        self.matC = None
        self.force_direction = []
        # Static lines that particles collide with (see add_line).
        self.line_array = []
        # Per-particle state, stored as parallel lists in insertion order.
        self.mass_array = []
        self.radius_array = []
        self.positions = []
        self.velocities = []
        self.accelerations = []
        self.particle_list = []
        self.sprites = []

        print("New universe has been created")

    def __repr__(self):
        return "Your universe contains the following particles: " + str(self.particle_list)

    def _recalculate_matA(self):
        """Rebuild ``matA[i, j] = GRAV_CONST * m_j`` with a zero diagonal.

        Dividing row ``i`` elementwise by ``matB`` gives the magnitude of the
        gravitational pull each other particle exerts on particle ``i``. The
        zero diagonal removes self-attraction.
        """
        #      | 0   m2   m3   m4  |
        #      | m1  0    m3   m4  |
        #  A = | m1  m2   0    m4  | * G
        #      | m1  m2   m3   0   |
        masses = np.asarray(self.mass_array, dtype=float)
        # Every row is the full mass vector. Build a plain ndarray here:
        # np.matrix is deprecated.
        self.matA = GRAV_CONST * np.tile(masses, (masses.size, 1))
        np.fill_diagonal(self.matA, 0.0)

    def _recalculate_matB(self):
        """Rebuild squared distances (``matB``) and unit directions (``force_direction``).

        ``matB[i, j] = |x_j - x_i|**2``. Exact zeros (the diagonal and any
        coincident pair) are replaced by 1 so that later divisions are safe.

        ``force_direction[i, j] = (x_j - x_i) / sqrt(matB[i, j])``. This is the
        unit vector from ``i`` towards ``j``, or zero when the two coincide.
        """
        if not self.positions:
            self.matB = np.empty((0, 0))
            self.force_direction = np.empty((0, 0, 0))
            return
        # Work in float. With integer positions, normalising an integer array
        # would truncate every direction vector.
        pos = np.asarray(self.positions, dtype=float)
        offsets = pos[np.newaxis, :, :] - pos[:, np.newaxis, :]  # offsets[i, j] = x_j - x_i
        self.matB = np.einsum("ijk,ijk->ij", offsets, offsets)
        self.matB[self.matB == 0.0] = 1.0
        self.force_direction = offsets / np.sqrt(self.matB)[:, :, np.newaxis]

    def _recalculate_matC(self):
        """Rebuild ``matC[i, j] = (r_i + r_j)**2``, the squared contact distance.

        The diagonal holds ``(2 * r_i)**2``. Collision checks only read
        entries with ``i < j``.
        """
        radii = np.asarray(self.radius_array, dtype=float)
        reach = radii[:, np.newaxis] + radii[np.newaxis, :]
        self.matC = reach * reach

    def add_particle(self, particle):
        """Register ``particle`` and rebuild all pairwise caches.

        Reads ``mass``, ``position``, ``velocity``, ``acceleration``, ``radius``
        and ``sprite`` from ``particle``. The sprite must provide
        ``update_pos``.
        """
        self.mass_array.append(particle.mass)
        self.positions.append(particle.position)
        self.velocities.append(particle.velocity)
        self.accelerations.append(particle.acceleration)
        self.particle_list.append(particle)
        self.radius_array.append(particle.radius)
        self.sprites.append(particle.sprite)
        self._recalculate_matA()
        self._recalculate_matB()
        self._recalculate_matC()

    def add_line(self, line):
        """Register a static line for collision tests in ``process``.

        Only ``line.position`` and ``line.vector`` are read.
        """
        self.line_array.append(line)

    def process(self):
        """Advance the simulation by one ``DELTA_TIME`` step.

        1. Integrate the motion. Acceleration is ``UNIFORM_ACCEL`` plus mutual
           gravity from the cached ``matA`` / ``matB`` / ``force_direction``.
           Velocities are updated first, then positions (semi-implicit Euler).
        2. Resolve collisions for every overlapping particle pair and for each
           particle against every registered line.
        3. Push the final, collision-corrected positions to the sprites.
           Refresh ``matB`` / ``force_direction`` so the next step's gravity
           sees them.

        Does nothing when the universe holds no particles.
        """
        count = len(self.positions)
        if count == 0:
            return

        self._integrate()

        for i in range(count):
            for j in range(i + 1, count):
                self._collide_particles(i, j)
            self._collide_lines(i)

        # Update the sprites only after collision correction, so overlaps
        # are never drawn.
        for sprite, position, radius in zip(self.sprites, self.positions, self.radius_array):
            sprite.update_pos(position - radius)

        self._recalculate_matB()

    def _integrate(self):
        """Apply gravity plus the uniform acceleration and take one Euler step."""
        # pull[i, j] = G * m_j / |x_j - x_i|**2. The diagonal is zero because
        # matA's diagonal is zero.
        pull = self.matA / self.matB
        uniform = np.asarray(self.UNIFORM_ACCEL, dtype=float)
        for i in range(len(self.positions)):
            accel = uniform + pull[i] @ self.force_direction[i]
            self.accelerations[i] = accel
            self.velocities[i] = accel * DELTA_TIME + self.velocities[i]
            self.positions[i] = self.velocities[i] * DELTA_TIME + self.positions[i]

    def _collide_particles(self, i, j):
        """Resolve contact between particles ``i`` and ``j`` if they overlap.

        Overlap is tested against the current positions, because earlier
        pairs in this step may already have moved either particle.

        An overlapping pair is pushed apart along the contact normal. Each
        particle moves in proportion to the other's mass. If the pair is
        also approaching, the normal velocity components are exchanged with
        restitution ``RESTITUTION``.
        """
        offset = np.asarray(self.positions[j], dtype=float) - self.positions[i]
        dist_sq = np.dot(offset, offset)
        if dist_sq > self.matC[i, j]:
            return
        dist = np.sqrt(dist_sq)
        if dist > 0.0:
            normal = offset / dist  # unit vector from i towards j
        else:
            # Coincident centres have no contact normal. Pick the first axis
            # rather than dividing by zero.
            normal = np.zeros_like(offset)
            normal[0] = 1.0

        m_i = self.mass_array[i]
        m_j = self.mass_array[j]
        total = m_i + m_j

        closing_speed = np.dot(self.velocities[i], normal) - np.dot(self.velocities[j], normal)
        if closing_speed > 0.0:
            # Impulse along the normal; tangential components are untouched.
            # Both particles get the same restitution, so momentum is
            # conserved. A separating pair is not bounced again, which
            # prevents touching particles from sticking.
            kick = (1.0 + self.RESTITUTION) * closing_speed / total
            self.velocities[i] = self.velocities[i] - (kick * m_j) * normal
            self.velocities[j] = self.velocities[j] + (kick * m_i) * normal

        overlap = np.sqrt(self.matC[i, j]) - dist
        self.positions[j] = self.positions[j] + normal * (overlap * m_i / total)
        self.positions[i] = self.positions[i] - normal * (overlap * m_j / total)

    def _collide_lines(self, i):
        """Push particle ``i`` out of, and bounce it off, any line it overlaps.

        The offset from the particle centre to ``line.position`` is projected
        onto ``line.vector`` to get the perpendicular gap.

        NOTE(review): this is a true distance only if ``line.vector`` is a
        unit normal of the line -- confirm against the line class.
        """
        radius = self.radius_array[i]
        for line in self.line_array:
            to_line = np.dot(line.position - np.asarray(self.positions[i]), line.vector) * line.vector
            gap = np.linalg.norm(to_line)
            if gap >= radius or gap == 0.0:
                # Either there is no contact, or the centre lies exactly on the
                # line. In the second case the push direction is undefined, and
                # dividing by the zero gap would produce NaN.
                continue
            direction = to_line / gap  # unit vector from the centre towards the line
            if np.dot(self.velocities[i], direction) > 0.0:
                # Reflect only while moving into the line. Reflecting a
                # particle that is already leaving would fling it away faster.
                normal_velocity = np.dot(self.velocities[i], line.vector) * line.vector
                self.velocities[i] = self.velocities[i] - 2 * np.linalg.norm(normal_velocity) * direction
            self.positions[i] = self.positions[i] - direction * (radius - gap)