import thread, profile, sys, os

# Keep a reference to the real thread starter: thread.start_new_thread is
# replaced by the profiling wrapper further down, and that wrapper needs
# the original function to actually spawn the thread.
start_new_thread = thread.start_new_thread
# Profiler for the main thread; the __main__ section runs the target
# script under it and prints the final report from it.
main_thread_profiler = profile.Profile()
# Every profiler whose timings are merged into the report.  Starts with the
# main thread's profiler; the wrapper appends one per thread it starts.
threaded_profilers = [main_thread_profiler, ]


def profiled_start_new_thread(func, args, kwargs=None):
    """Start ``func(*args, **kwargs)`` in a new thread under its own profiler.

    Intended as a drop-in replacement for ``thread.start_new_thread``.
    Inside the new thread a fresh ``profile.Profile`` is created, appended
    to ``threaded_profilers`` and used to run the call.

    Args:
        func: callable to execute in the new thread.
        args: positional arguments for ``func``.
        kwargs: optional dict of keyword arguments for ``func``; omitted or
            ``None`` means no keyword arguments.

    Returns:
        Whatever the saved original ``start_new_thread`` returns for the
        new thread.
    """
    if kwargs is None:
        kwargs = {}

    # The values are bound as default arguments, so the thread body uses
    # the objects as they were when the thread was requested.
    def f(profile=profile, func=func, args=args, kwargs=kwargs):
        profiler = profile.Profile()
        # Register before running so the profiler is known even if the
        # call raises or is still in progress when timings are merged.
        threaded_profilers.append(profiler)
        profiler.runcall(func, *args, **kwargs)

    # Hand the result back to the caller, as the wrapped function does.
    return start_new_thread(f, ())

# Install the profiling wrapper in place of the real function, so code that
# looks up thread.start_new_thread from now on starts profiled threads.
# Anything that already holds a direct reference to the original function
# (obtained before this line ran) is not affected.
thread.start_new_thread = profiled_start_new_thread
    

def mergeTimings(profilers):
    """Combine the ``timings`` tables of several profilers into one table.

    Each object in ``profilers`` must expose a ``timings`` dict whose values
    unpack into four addable numbers followed by a dict of counts.  Entries
    stored under the same key are summed position by position, and their
    trailing dicts are summed key by key.

    NOTE(review): for ``profile.Profile`` the five fields are presumably
    call count, recursion level, own time, cumulative time and callers;
    confirm against the profile module before relying on that.

    Returns a new dict mapping every key seen to a five-item list
    ``[sum1, sum2, sum3, sum4, merged_dict]``.  The input tables are only
    read.
    """
    merged = {}
    for profiler in profilers:
        for key, (first, second, third, fourth, counts) in profiler.timings.items():
            if key not in merged:
                merged[key] = [0, 0, 0, 0, {}]
            totals = merged[key]

            # Accumulate the four numeric fields in their original order.
            for slot, amount in enumerate((first, second, third, fourth)):
                totals[slot] += amount

            # Accumulate the nested per-key counts.
            nested = totals[4]
            for sub_key, count in counts.items():
                nested[sub_key] = nested.get(sub_key, 0) + count
    return merged


# When invoked as main program, invoke the profiler on a script
if __name__ == '__main__':
    if not sys.argv[1:]:
        # A parenthesised single string prints the same text under the
        # Python 2 print statement and also parses as a function call.
        print("usage: threaded_profile.py scriptfile [arg] ...")
        sys.exit(2)

    filename = sys.argv[1]    # Get script filename

    del sys.argv[0]        # Hide "threaded_profile.py" from argument list

    # Insert script directory in front of module search path
    sys.path.insert(0, os.path.dirname(filename))

    # Profile the passed file.  A SystemExit raised by the script is
    # swallowed so the report below is still produced; any other exception
    # propagates and the report is skipped.
    try:
        # repr() yields the same quoted literal the backtick form did.
        main_thread_profiler.run('execfile(' + repr(filename) + ')')
    except SystemExit:
        pass

    # Fold the timings of every registered profiler (the main thread's plus
    # one per thread started through the wrapper) into the main profiler.
    main_thread_profiler.timings = mergeTimings(threaded_profilers)

    # print out the summarized stats
    main_thread_profiler.print_stats()

