# -*- coding: utf-8 -*-


from random import randint, gauss
from itertools import izip


#from cpp import virtualmachine
#def cpp_vm():
#    vm = virtualmachine()
##    vm.load('binaries/bin1.obf')
#    return vm

class Unviable(Exception):
    """Signals that a breeding attempt could not produce an offspring.

    Population's breeding step catches this exception and substitutes a
    freshly generated random chromosome for the failed child.
    """

class Population(object):
    """A pool of scored chromosomes that is evolved one generation at a time.

    Attributes:
        desired_size: target number of entries in ``chromosomes``.
        random_chromosome: zero-argument callable returning a new chromosome.
        fitness_function: callable mapping a chromosome to a comparable score
            (larger is better).
        survival_ratio: fraction used to cut the ranked pool (see ``iterate``).
        sex: callable ``sex(father, mother, child)`` that fills the empty
            list ``child`` in place; it may raise ``Unviable``.
        chromosomes: list of ``(chromosome, fitness)`` pairs.
        generation: number of completed ``iterate`` calls.
    """

    def __init__(population, desired_size, random_chromosome, fitness_function, survival_ratio, sex):
        """Create and score ``desired_size`` random chromosomes."""
        population.desired_size = desired_size
        population.random_chromosome = random_chromosome
        population.fitness_function = fitness_function
        population.survival_ratio = survival_ratio
        population.sex = sex

        chromosomes = []
        for _ in range(desired_size):
            chromosome = random_chromosome()
            chromosomes.append((chromosome, fitness_function(chromosome)))
        population.chromosomes = chromosomes
        population.generation = 0

    def _breed(population, father, mother):
        """Return a scored ``(chromosome, fitness)`` child of two pool entries.

        ``father`` and ``mother`` are ``(chromosome, fitness)`` pairs.  A
        random chromosome is used instead of the child when breeding raises
        ``Unviable`` or when the child is identical to one of its parents.
        """
        father, _ = father
        mother, _ = mother

        if len(father) > len(mother):
            # Pad the shorter mother with the father's trailing genes so both
            # parents have the same number of genes.
            mother = list(mother) + list(father[len(mother):])

        chromosome = []
        try:
            # Use the breeding function injected into the constructor rather
            # than whatever module-level ``sex`` happens to exist.
            population.sex(father, mother, chromosome)
        except Unviable:
            print("Unviable!")
            chromosome = population.random_chromosome()
        else:
            if chromosome == father or chromosome == mother:
                print("chromosome same as parent!")
                chromosome = population.random_chromosome()
            else:
                print("father: %s" % (father,))
                print("mother: %s" % (mother,))
                print("offspr: %s" % (chromosome,))
        return chromosome, population.fitness_function(chromosome)

    def iterate(population):
        """Advance the population by one generation.

        The current pool is ranked by fitness, the lowest-ranked
        ``int(len(pool) * survival_ratio)`` entries are discarded, one
        unscored random chromosome is added as fresh genetic material, and
        children are bred from the survivors until the pool again holds
        ``desired_size`` entries.

        NOTE(review): ``survival_ratio`` therefore acts as the fraction that
        is culled, not kept; the two coincide for the value 0.5.

        Returns:
            The pool as it was ranked at the start of the call: a list of
            ``(fitness, chromosome)`` pairs, best first.  Entries whose
            fitness is ``None`` (the random entry added by the previous call)
            come last.
        """
        chromosomes_by_fitness = [
            (fitness, chromosome)
            for chromosome, fitness in population.chromosomes
        ]
        # Unscored entries (fitness None) rank below every scored entry.  The
        # key spells that out instead of relying on Python 2 ordering None
        # before all other values.
        chromosomes_by_fitness.sort(key=lambda entry: (entry[0] is not None, entry))

        best_slice = int(len(chromosomes_by_fitness) * population.survival_ratio)
        survivors = [
            (chromosome, fitness)
            for fitness, chromosome in chromosomes_by_fitness[best_slice:]
        ]
        # Sorting by chromosome places similar chromosomes next to each other
        # (this implies that the genes are ordered in importance); mothers are
        # then picked at an index near the father's.
        survivors.sort()

        survivors.append((population.random_chromosome(), None))  # entropy
        new_generation = []
        num_survivors = len(survivors)
        while population.desired_size > len(new_generation) + num_survivors:
            father = randint(0, num_survivors - 1)
            mother = -1
            # Redraw until the normally distributed offset lands on a valid index.
            while not 0 <= mother < num_survivors:
                mother = int(father + gauss(0, num_survivors))
            new_generation.append(
                population._breed(survivors[father], survivors[mother])
            )

        population.chromosomes = survivors + new_generation
        population.generation += 1
        chromosomes_by_fitness.reverse()
        return chromosomes_by_fitness

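# Gene fields, in order of importance:
#   time      (positive integer)
#   direction (-inf to +inf, double)
#   magnitude (-inf to +inf, double)
# NOTE(review): the genes built in this file are (time, magnitude) pairs with
# no direction field -- confirm whether this list is still current.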

def random_chromosome(size, wait_sigma=10000, magnitude_sigma=0.01):
    """Build a factory that produces random chromosomes of ``size`` genes.

    Args:
        size: number of genes in every generated chromosome.
        wait_sigma: standard deviation of the zero-mean normal distribution
            the wait is drawn from (defaults to the previous fixed 10000).
        magnitude_sigma: standard deviation of the zero-mean normal
            distribution the magnitude is drawn from (defaults to the
            previous fixed 0.01).

    Returns:
        A zero-argument callable.  Each call returns a new list of
        ``(wait, magnitude)`` tuples, where ``wait`` is a non-negative integer
        and ``magnitude`` is a float.
    """
    def r_chromosome():
        chromosome = []
        for _ in range(size):
            # Negative draws are clamped, so roughly half the waits are 0.
            wait = max(0, int(gauss(0, wait_sigma)))
            chromosome.append((
                wait,
                gauss(0, magnitude_sigma),  # magnitude
            ))
        return chromosome
    return r_chromosome

def sex(father, mother, chromosome):
    """Breed two parents, appending the offspring's genes to ``chromosome``.

    One of three strategies is picked at random:

    * recombination (probability 1/2): the child takes the father's genes up
      to a random split point and the mother's genes after it;
    * averaging (1/4): each child gene is the mean of the paired parent genes
      plus normally distributed noise;
    * mutation (1/4): each child gene is the father's gene plus noise (the
      mother is not used).

    Args:
        father: sequence of ``(time, magnitude)`` genes; ``time`` is an
            integer.
        mother: sequence of ``(time, magnitude)`` genes.
        chromosome: list (normally empty) that receives the child's genes.
    """
    if randint(0, 1):
        if randint(0, 1):
            print("averaging")
            # zip stops at the shorter parent.
            for (f_time, f_magnitude), (m_time, m_magnitude) in zip(father, mother):
                # Times are integers; ``//`` keeps the mean an integer on
                # every Python version.
                new_time = max(0, (f_time + m_time) // 2 + int(gauss(0, 100)))
                chromosome.append((
                    new_time,
                    ((f_magnitude + m_magnitude) / 2) + gauss(0, 0.01),
                ))
        else:
            print("mutation")
            for f_time, f_magnitude in father:
                new_time = max(0, f_time + int(gauss(0, 100)))
                chromosome.append((
                    new_time,
                    f_magnitude + gauss(0, 0.01),
                ))
    else:
        print("recombination")
        gene_count = len(father)
        # A father with fewer than two genes has no interior split point;
        # randint(1, 0) would raise ValueError, so copy the father instead.
        split = randint(1, gene_count - 1) if gene_count > 1 else gene_count
        chromosome.extend(father[:split])
        chromosome.extend(mother[split:])

# Keys of the thrust mapping returned by ChromosomeFlier.__call__.
# NOTE(review): presumably the simulator's x/y thrust port numbers -- confirm.
X = 0x2
Y = 0x3
from math import sqrt

class ChromosomeFlier(object):
    """Controller that replays a plan of ``(wait, thrust)`` steps.

    Subclasses are expected to supply the starting position as ``x`` / ``y``
    attributes and an ``update_scores`` method; both are used by ``__call__``
    but not defined here.
    """

    def __init__(flier, plan, visualiser = None):
        """Copy ``plan`` and arm its first step relative to time 0."""
        flier.plan = list(plan)
        flier.next_step(0)
        flier.score = 0.0
        flier.visualiser = visualiser

    def next_step(flier, t):
        """Arm the next plan step relative to time ``t``.

        When the plan is exhausted, ``next_time`` becomes -1.
        """
        if not flier.plan:
            flier.next_time = -1
            return
        delay, gain = flier.plan.pop(0)
        flier.next_time = t + delay
        flier.thrust = gain

    def __call__(flier, t, score, fuel_remaining, x, y, *args):
        """Record the new position and return the thrust for time ``t``.

        Returns a dict keyed by ``X`` and ``Y``.  Both values are 0.0 unless
        ``t`` equals the armed ``next_time``; in that case each value is the
        displacement since the previous call scaled by the step's thrust
        factor, and the following plan step is armed.

        NOTE(review): a step with wait 0 arms ``next_time`` to the current
        ``t``; it then only fires if that same ``t`` is seen again -- confirm
        this is intended.
        """
        # Displacement since the previous call (or since the class-level
        # starting position on the first call).
        flier.dx, flier.dy = x - flier.x, y - flier.y
        flier.x, flier.y, flier.t = x, y, t
        if score:
            flier.score = score

        flier.update_scores(t, score, fuel_remaining, x, y, *args)

        command = {X: 0.0, Y: 0.0}
        if t == flier.next_time:
            command[X] = flier.dx * flier.thrust
            command[Y] = flier.dy * flier.thrust
            flier.next_step(t)

        if flier.visualiser:
            flier.visualiser.visualise(t, score, fuel_remaining, x, y, *args)

        return command

    def distance_to(flier, x_, y_):
        """Return the Euclidean distance from the stored position to (x_, y_)."""
        offset_x = flier.x - x_
        offset_y = flier.y - y_
        return sqrt((offset_x**2) + (offset_y**2))

    def height_diff(flier, target_orbit_radius):
        """Return ``target_orbit_radius`` minus the stored position's distance from the origin."""
        current_radius = sqrt((flier.x * flier.x) + (flier.y * flier.y))
        return target_orbit_radius - current_radius

import translated_binaries.bin1

class Hohmann(ChromosomeFlier):
    """Flier scored on how closely it holds a target orbit radius.

    On every call the absolute gap between the target radius and the
    flier's distance from the origin is added to ``sum_distance``.
    """
    VM = translated_binaries.bin1.OrbitalVirtualMachine
    port_count = 5

    # Starting position.  NOTE(review): "1001" is presumably the scenario id
    # these coordinates belong to -- confirm.
    x = -6557009.3141799988
    y = 0.0

    def __init__(flier, plan, **kwargs):
        """Initialise the plan and the accumulated score state."""
        super(Hohmann, flier).__init__(plan, **kwargs)
        flier.sum_distance = 0.0
        flier.achieved_radius = 0

    def update_scores(flier, t, score, fuel_remaining, x, y, target_orbit_radius):
        """Accumulate the current absolute gap to ``target_orbit_radius``."""
        # height_diff is a method of the flier; calling it as a bare name
        # raised NameError.
        flier.distance = abs(flier.height_diff(target_orbit_radius))
        flier.sum_distance += flier.distance

    def get_score(flier, time_out):
        """Return the fitness tuple ``(score, 1e13 / sum_distance)``.

        If the last recorded time is before ``time_out``, the last gap is
        charged for every remaining time unit.  Must only be called after
        the flier has been invoked at least once.
        """
        sum_distance = flier.sum_distance
        if flier.t < time_out:
            sum_distance += (time_out - flier.t) * flier.distance
        # A zero total means the radius was held exactly the whole time;
        # rank that above every finite value instead of dividing by zero.
        closeness = 1e13 / sum_distance if sum_distance else float("inf")
        return (
            flier.score,
            closeness, # need this to be fair and second, not third
        )

import translated_binaries.bin2
import translated_binaries.bin3

class MeetAndGreet(ChromosomeFlier):
    """Flier scored on how well it tracks a moving target.

    Three penalties are accumulated over the run: distance to the target
    (``sum_distance``), mismatch between the flier's and the target's
    per-call displacement (``velocity_diff``), and deviation from the radius
    computed on the first call (``orbit_diff``).
    """
    VM = translated_binaries.bin2.OrbitalVirtualMachine
    port_count = 6
    # Starting position.  NOTE(review): "2001" is presumably the scenario id
    # these coordinates belong to -- confirm.
    x = -6557009.3141799988
    y = 0.0

    def __init__(flier, plan, **kwargs):
        """Initialise the plan, the penalty accumulators and the plot point."""
        super(MeetAndGreet, flier).__init__(plan, **kwargs)
        flier.sum_distance = 0.0
        flier.velocity_diff = 0.0
        flier.x_target = None  # no target sample seen yet
        flier.orbit_diff = 0.0
        if flier.visualiser:
            flier.point = flier.visualiser.point((255,255,0))

    def update_scores(flier, t, score, fuel_remaining, x, y, x_target, y_target):
        """Accumulate the distance between the stored position and the target."""
        flier.distance = flier.distance_to(x_target, y_target)
        flier.sum_distance += flier.distance

    def get_score(flier, time_out):
        """Return ``(score, -orbit_diff, -velocity_diff, -sum_distance)``.

        If the last recorded time is before ``time_out``, each penalty is
        increased by the remaining time multiplied by its current value (the
        last distance for ``sum_distance``; the running totals for the other
        two).  Penalties are negated so that larger tuples are better.
        """
        sum_distance = flier.sum_distance
        velocity_diff = flier.velocity_diff
        orbit_diff = flier.orbit_diff
        if flier.t < time_out:
            remaining = time_out - flier.t
            sum_distance += remaining * flier.distance
            velocity_diff += remaining * flier.velocity_diff
            orbit_diff += remaining * flier.orbit_diff

        return (
            flier.score,
            -orbit_diff,
            -velocity_diff,
            -sum_distance, # need this to be fair and second, not third
        )

    def __call__(flier, t, score, fuel_remaining, x, y, x_target, y_target):
        """Run the base controller, then update the target-tracking penalties.

        Returns the thrust mapping produced by ``ChromosomeFlier.__call__``.
        """
        result = super(MeetAndGreet, flier).__call__(t, score, fuel_remaining, x, y, x_target, y_target)

        # Compare against None explicitly: a legitimate previous x_target of
        # 0.0 is falsy and used to be mistaken for "no previous sample",
        # which skipped the velocity update and recomputed target_orbit.
        if flier.x_target is not None:
            dx_target = x_target - flier.x_target
            dy_target = y_target - flier.y_target
            flier.velocity_diff += abs(flier.dx - dx_target) + abs(flier.dy - dy_target)
        else:
            # First call: fix the reference radius.
            # NOTE(review): the sum x + x_target suggests the target inputs
            # are offsets relative to the flier -- confirm against the VM.
            flier.target_orbit = sqrt((x+x_target)**2 + (y+y_target)**2)

        flier.x_target = x_target
        flier.y_target = y_target
        flier.orbit_diff += abs(flier.height_diff(flier.target_orbit))

        if flier.visualiser:
            flier.visualiser.plot_point(t,
                x - x_target,
                y - y_target,
                flier.point
            )
        return result

from simulation import run_simulation

# Time limit handed to run_simulation (as ``max``) and to the fliers'
# get_score methods.
time_out = 40000
def fitness_function(Flier, scenario_id):
    """Build a fitness function for chromosomes flown by ``Flier``.

    The returned callable prints the chromosome, flies it through
    ``run_simulation`` on a fresh ``Flier.VM()`` for ``scenario_id`` and
    returns ``flier.get_score(time_out)``.
    """
    def fitness_function(chromosome):
        print(chromosome)
        pilot = Flier(chromosome)
        machine = Flier.VM()
        run_simulation(machine, scenario_id, pilot, max = time_out)
        return pilot.get_score(time_out)
    return fitness_function


from visualiser import Visualiser

def visualise(controller, visualiser=None):
    """Wrap ``controller`` so every call is drawn and logged first.

    Args:
        controller: callable taking ``(t, score, fuel_remaining, x, y, _,
            *args)`` and exposing a ``port_count`` attribute.
        visualiser: object providing ``visualise`` and ``log`` methods.  When
            omitted, a ``Visualiser((500, 500))`` is created (the same size
            ``evolve`` uses).  The wrapper used to read an undefined global
            ``visualiser``, which raised NameError on its first call.
            NOTE(review): confirm that Visualiser provides ``log``.

    Returns:
        A callable with the same ``port_count`` as ``controller``.  It passes
        the first six arguments to the visualiser and then returns
        ``controller``'s result for the full argument list.
    """
    if visualiser is None:
        visualiser = Visualiser((500, 500))

    def visualisation(t, score, fuel_remaining, x, y, _, *args):
        visualiser.visualise(t, score, fuel_remaining, x, y, _)
        visualiser.log(t, score, fuel_remaining, x, y, _)
        return controller(t, score, fuel_remaining, x, y, _, *args)
    visualisation.port_count = controller.port_count
    return visualisation

def evolve(Flier, scenario, pop_size, no_genes, generations=100):
    """Evolve thrust plans for ``Flier`` on ``scenario`` and report progress.

    Each generation's ranked pool is printed.  Whenever the top entry differs
    from the previous generation's, its chromosome is flown again with a
    500x500 visualiser attached and the resulting score tuple is printed.

    Args:
        Flier: flier class; must provide ``VM`` and ``get_score``.
        scenario: scenario id passed to ``run_simulation``.
        pop_size: number of chromosomes in the population.
        no_genes: number of genes per chromosome.
        generations: number of generations to run (defaults to the previous
            fixed count of 100).
    """
    population = Population(
        pop_size,
        random_chromosome = random_chromosome(no_genes),
        fitness_function = fitness_function(Flier, scenario),
        survival_ratio = 0.5,
        sex = sex,
    )

    best = None
    for generation in range(generations):
        ranked = population.iterate()
        print("Best from generation %i:" % generation)
        for fitness, chromosome in ranked:
            print("%s %s" % (fitness, chromosome))
        new_best = ranked[0]
        print(new_best)
        print(best)
        if best != new_best:
            _, chromosome = new_best
            flier = Flier(chromosome, visualiser = Visualiser((500, 500)))
            run_simulation(
                Flier.VM(),
                scenario,
                flier,
                max = time_out
            )
            # Report the flier's own score tuple so the printed value matches
            # the fitness used for ranking; the previous inline formula was a
            # copy of Hohmann's scoring regardless of the flier class.
            print(flier.get_score(time_out))
            best = new_best

# Runs as soon as the module is executed or imported: evolve 4-gene plans for
# the MeetAndGreet flier on scenario 2001 with a population of 20.
evolve(Flier=MeetAndGreet, scenario=2001, pop_size=20, no_genes=4)
