#!/usr/bin/env python3
# coding: utf-8

import json
import pandas
import findspark
findspark.init()
import pyspark
import pyspark.sql
import sys
import gzip

from pyspark import AccumulatorParam
from pyspark.sql.functions import lit
from pyspark.sql import Window
from pyspark.sql.types import ByteType

if len(sys.argv) != 4:
    # sys.exit with a string prints it to stderr and exits with status 1, so a
    # usage error is not reported as success.
    sys.exit(sys.argv[0] + " {cluster} {tmpdir} {maxram}")

cluster = sys.argv[1]

# tmpdir is used as Spark's scratch directory (spark.local.dir) and maxram as
# the driver memory setting (spark.driver.memory).
spark = (
    pyspark.sql.SparkSession.builder
    .appName("figure_9b")
    .config("spark.driver.maxResultSize", "128g")
    .config("spark.local.dir", sys.argv[2])
    .config("spark.driver.memory", sys.argv[3])
    .getOrCreate()
)
sc = spark.sparkContext

# NOTE(review): the rows are later read for instance_index/priority/type/time,
# which look like per-instance event fields, while this glob points at the
# collection_events directory -- confirm it targets the intended dataset.
dfepath = ("/home/claudio/google_2019/collection_events/" + cluster + "/"
           + cluster + "_collection_events*.json.gz")
#dfepath="/home/claudio/google_2019/collection_events/" + cluster + "/" + cluster + "_test.json"

df = spark.read.json(dfepath)

# Normalise collection_type to a byte column, or add an all-null one when the
# input has no such column. The presence check is explicit: a DataFrame does
# not support item assignment, and a bare except around it would silently
# null the column for every input.
if "collection_type" in df.columns:
    df = df.withColumn("collection_type",
                       df["collection_type"].cast(ByteType()))
else:
    df = df.withColumn("collection_type", lit(None).cast(ByteType()))

# Scale factor used by sumrow(): time // (MICROS * 60) gives whole minutes.
MICROS = 1000000

def sumrow(l, p, t, c):
    """Replace an execution time by its duration-bucket label.

    ``t`` is divided by ``MICROS * 60`` (i.e. treated as microseconds) and
    floored to whole minutes, then mapped to one of the labels
    "<1", "1-2", "2-4", "4-10", "10-60", "60-1d" or ">=1d".
    ``l``, ``p`` and ``c`` are passed through unchanged.

    Returns the tuple ``(l, p, label, c)``.
    """
    minutes = t // (MICROS * 60)
    # (exclusive upper bound in minutes, label), checked in ascending order.
    buckets = (
        (1, "<1"),
        (2, "1-2"),
        (4, "2-4"),
        (10, "4-10"),
        (60, "10-60"),
        (60 * 24, "60-1d"),
    )
    for bound, label in buckets:
        if minutes < bound:
            return (l, p, label, c)
    # Anything of a day or longer.
    return (l, p, ">=1d", c)


def sumid(sr):
    """Return the grouping key of a summary row: its first three fields.

    For a row built by ``sumrow`` this is ``(last_term, priority, bucket)``;
    the trailing count is left out. Always returns a tuple.
    """
    return tuple(sr[idx] for idx in range(3))


def for_each_task(ts):
    """Summarise the event history of one task instance.

    ``ts`` is an iterable of event dicts as built by ``cleanup`` (keys
    ``time``, ``type`` and ``priority``). Events are processed in ``time``
    order, and the function returns
    ``sumrow(last_term, priority, exec_tot, count)`` where:

    * ``last_term`` is the type of the last event whose type is in 4..8, or
      -1 if there is none;
    * ``priority`` is the first priority that is not -1 (``cleanup`` uses -1
      for a missing priority), or -1 if none is known;
    * ``exec_tot`` is the summed length of every interval that opens on a
      type-3 event and closes on the next type-1 or type-4..8 event, in the
      same unit as ``time``. An interval still open after the last event is
      not counted;
    * ``count`` is the number of events.
    """
    # NOTE(review): the type codes presumably follow the Google 2019 Borg
    # trace (1 = QUEUE, 3 = SCHEDULE, 4..8 = EVICT/FAIL/FINISH/KILL/LOST);
    # confirm against the trace schema.
    events = sorted(ts, key=lambda e: e["time"])

    in_exec = False
    exec_start = None
    exec_tot = 0
    priority = -1
    last_term = -1

    for e in events:
        etype = e["type"]
        # Keep the first known priority; -1 marks "missing". Compare by value:
        # `is` on ints only worked through CPython's small-int cache.
        if priority == -1 and e["priority"] != -1:
            priority = e["priority"]
        is_term = 4 <= etype <= 8
        if is_term:
            last_term = etype
        # Close an open execution interval.
        if in_exec and (etype == 1 or is_term):
            exec_tot += e["time"] - exec_start
            in_exec = False
        # Open a new execution interval.
        if not in_exec and etype == 3:
            exec_start = e["time"]
            in_exec = True

    return sumrow(last_term, priority, exec_tot, len(events))


def cleanup(x):
    """Turn a raw event row into the compact dict consumed by ``for_each_task``.

    ``x`` must expose ``time``, ``type``, ``collection_id``, ``instance_index``
    and ``priority`` attributes (a Spark ``Row`` in this script). ``time`` is
    converted with ``int``. A missing (None) ``type`` becomes 0, and a missing
    ``priority`` becomes -1. The task id is
    ``"<collection_id>-<instance_index>"``.
    """
    return {
        "time": int(x.time),
        "type": 0 if x.type is None else int(x.type),
        # str() keeps string ids unchanged and also accepts integer-typed
        # columns, which plain `+` concatenation would reject with TypeError.
        "id": str(x.collection_id) + "-" + str(x.instance_index),
        "priority": -1 if x.priority is None else int(x.priority),
    }

def sum_rows(xs):
    """Return the total of field index 3 over every row in ``xs``.

    An empty iterable yields 0.
    """
    return sum(row[3] for row in xs)

# Summarise every task instance with for_each_task, then aggregate by
# (last_term, priority, exec-time bucket) and write one
# "last_term,priority,bucket,total" line per key.
# NOTE(review): the aggregated field is the per-task event count from
# for_each_task (len of its events), so "total" counts events, not tasks;
# confirm this is intended.
(
    df.rdd
    .filter(lambda x: x.collection_type is None or x.collection_type == 0)
    # cleanup() and the task id need these fields.
    .filter(lambda x: x.time is not None
            and x.instance_index is not None
            and x.collection_id is not None)
    .map(cleanup)
    # Gather all events of one task instance ("<collection_id>-<instance_index>").
    .groupBy(lambda x: x["id"])
    .mapValues(for_each_task)
    .values()
    # reduceByKey combines map-side. groupBy would instead ship every per-task
    # row to the few reducers owning the small set of distinct keys.
    .map(lambda row: (sumid(row), row[3]))
    .reduceByKey(lambda a, b: a + b)
    .map(lambda kv: ",".join(str(v) for v in kv[0] + (kv[1],)))
    # Single part file; saveAsTextFile fails if the output directory exists.
    .coalesce(1)
    .saveAsTextFile(cluster + "_priority_exectime")
)

# vim: set ts=4 sw=4 et tw=80: