import datetime
import bz2
import h5py
import numpy as np
import pandas as pd
import scipy.sparse as sps
from scipy.sparse import isspmatrix_csr,csr_matrix,csc_matrix,issparse,isspmatrix_csr
import math
import os
def generate_dee2_h5(data_file_name, metadata_file, gene_info_file_name, h5_gsm_name):
    """Build a GSM-level dee2 expression HDF5 file from long-format dumps.

    Parameters
    ----------
    data_file_name : str
        Tab-separated long-format expression file with columns
        (srr, gene, se). Assumed to hold exactly n_genes contiguous rows
        per SRR, genes in annotation order -- TODO confirm against the
        dee2 dump format (see NOTE in the chunk loop).
    metadata_file : str
        Tab-separated dee2 metadata, one row per SRR accession.
    gene_info_file_name : str
        Tab-separated gene annotation with "GeneID" and "GeneSymbol" columns.
    h5_gsm_name : str
        Path of the HDF5 file to create (overwritten if present).

    Side effects: writes the HDF5 file and prints progress to stdout.
    SRR runs mapped to the same GEO sample (GSM) have their counts summed
    into one column of /data/expression (genes x GSMs).
    """
    dt = h5py.special_dtype(vlen = str)  # variable-length string dtype for HDF5
    now = datetime.datetime.now()
    print("make gene index...")
    gene_info = pd.read_csv(filepath_or_buffer = gene_info_file_name, sep = "\t")
    gene_ind = {gene: ind for ind, gene in enumerate(gene_info["GeneID"])}
    n_genes = len(gene_ind)
    print("gene quantity: " + str(n_genes))
    print("read meta file...")
    iter_meta = pd.read_csv(filepath_or_buffer = metadata_file, sep = "\t", iterator = True, chunksize = 1000)
    meta_df = pd.concat(iter_meta)  # concat accepts the iterator directly
    geo_acc_name = "sample_alias"
    geo_gse_name = "GEO_series"
    meta_df[geo_gse_name] = meta_df[geo_gse_name].astype(str)
    true_srr = len(meta_df)
    print("real SRR quantity: " + str(true_srr))
    # Keep only runs that are linked to a GEO sample accession (GSM...).
    meta_df = meta_df[meta_df[geo_acc_name].str.startswith('GSM')]
    n_srr = len(meta_df)
    print("SRRs with gse: " + str(n_srr))
    print("make SRR and gsm indexes...")
    srr_to_gsm = dict(zip(meta_df["SRR_accession"], meta_df[geo_acc_name]))
    srr_ind = {srr: ind for ind, srr in enumerate(srr_to_gsm.keys())}
    gsm_ind = {gsm: ind for ind, gsm in enumerate(pd.unique(meta_df[geo_acc_name]))}
    col_list = [geo_acc_name, "SRR_accession", "experiment_instrument_model", "experiment_library_selection", "experiment_library_source", "experiment_library_strategy", "sample_scientific_name", geo_gse_name, "QC_summary",  "Experiment_title"]
    meta_df = meta_df[col_list]
    n_gsm = len(gsm_ind)
    print("GSMs quantity:" + str(n_gsm))
    # 0/1 incidence matrix: entry [g, s] == 1 iff SRR s belongs to GSM g.
    # Built as LIL (cheap incremental writes), converted to CSR before use.
    gsm_from_srr_matrix = sps.lil_matrix((n_gsm, n_srr))
    for srr_key in srr_to_gsm:
        gsm_from_srr_matrix[gsm_ind[srr_to_gsm[srr_key]], srr_ind[srr_key]] = 1
    print("start h5...")
    with h5py.File(h5_gsm_name, 'w') as h5_gse:
        print("create file and write info...")
        h5_gse.create_dataset('/meta/info/version', data = "10")
        h5_gse.create_dataset('/meta/info/creation_date', data = now.strftime("%Y-%m-%d"))
        h5_gse.create_dataset('/meta/info/database', data = "dee2")
        h5_gse.create_dataset('/meta/info/source_file', data = os.path.basename(data_file_name))
        h5_gse.create_dataset('/meta/info/author', data = "Maksim Kleverov")
        h5_gse.create_dataset('/meta/info/contact', data = "klevermx@gmail.com")
        exp_data = h5_gse.create_dataset("/data/expression", (n_genes, n_gsm), dtype = 'i4', chunks = (200, 200), compression = "gzip", compression_opts = 7)
        srr_per_time = 200
        # Chunk size is a multiple of n_genes so that, when the file stores
        # exactly n_genes rows per SRR, no SRR straddles a chunk boundary.
        # NOTE(review): an SRR that does straddle a boundary (or has missing
        # genes) fails the completeness check below and is dropped as "bad".
        iter_data = pd.read_csv(filepath_or_buffer = data_file_name, sep = "\t", iterator = True, chunksize = srr_per_time * n_genes, names = ["srr", "gene", "se"])
        proc_srr = 0
        gsm_from_srr_matrix = csr_matrix(gsm_from_srr_matrix)
        print("start process expression...")
        bad_srrs = []
        old_proc = 0
        for chunk in iter_data:
            # Distinct genes seen per SRR in this chunk: completeness check.
            count_genes = chunk["gene"].groupby(chunk['srr'], as_index = True).nunique()
            srr_count = len(count_genes)
            proc_srr = proc_srr + srr_count
            count_genes = count_genes[count_genes.index.isin(srr_ind)]
            # Only SRRs with a full complement of genes are ingested.
            local_srr_ind = count_genes[count_genes == n_genes].index
            diff = count_genes.index.difference(local_srr_ind)
            bad_srrs.extend(diff)  # in-place extend; was quadratic list rebuilding
            if len(local_srr_ind):
                local_srr_ind = pd.DataFrame(data = range(0, len(local_srr_ind)), index = local_srr_ind, columns = ["local_pos"])
                chunk = chunk[chunk["srr"].isin(local_srr_ind.index)]
                local_srr_ind['global_pos'] = local_srr_ind.index.map(srr_ind)
                # NOTE(review): this reshape assumes the chunk's rows are grouped
                # by SRR in the same (sorted) order as local_srr_ind and that each
                # SRR's genes appear in gene_ind order -- TODO confirm against the
                # dee2 dump format.
                raw_matrix = csr_matrix(chunk["se"].values.astype(int).reshape(len(local_srr_ind), n_genes))
                A = gsm_from_srr_matrix[:, local_srr_ind["global_pos"]]  # size = all_gsm x needed_srr
                gsm_mask = np.unique(A.nonzero()[0])  # GSM rows touched by this chunk
                A = A.tocsr()
                A = A.dot(raw_matrix)  # sum SRR counts into their GSMs
                A = A.transpose()      # -> genes x GSMs, matching exp_data layout
                exp_data[:, gsm_mask] += A[:, gsm_mask]
            if proc_srr - old_proc > 500:
                old_proc = proc_srr
                print(proc_srr / true_srr)
        print("publish meta...")
        print("start grouping...")

        # Blank per-run metadata of rejected SRRs in one vectorized pass (the
        # original per-SRR loop rescanned the whole DataFrame for every bad SRR).
        # The accession columns are kept so the GSM row still exists.
        if bad_srrs:
            blank_cols = [col for col in col_list if col not in (geo_acc_name, geo_gse_name)]
            meta_df.loc[meta_df["SRR_accession"].isin(bad_srrs), blank_cols] = ""
        meta_df = meta_df.fillna("")
        # Collapse per-SRR rows to one row per GSM, ';'-joining the values.
        meta_df = meta_df.groupby(geo_acc_name).agg([lambda x: ';'.join(x)])
        print("groupped")
        # Reorder rows to match the column order of /data/expression.
        meta_df = meta_df.loc[gsm_ind.keys()]
        print("reordered")
        # Re-index the annotation loaded above instead of re-reading the file.
        gene_info = gene_info.set_index("GeneID")
        gene_info = gene_info.fillna("")
        ensem = np.array(list(gene_ind.keys()), dtype = dt)
        gene_symbol = gene_info.loc[ensem]["GeneSymbol"]
        h5_gse.create_dataset('/meta/genes/ensembl_gene_id', data = ensem, chunks = (len(gene_ind),), compression = "gzip", compression_opts = 7)
        h5_gse.create_dataset('/meta/genes/gene_symbol', data = np.array(gene_symbol, dtype = dt), chunks = (len(gene_ind),), compression = "gzip", compression_opts = 7)
        h5_gse.create_dataset('/meta/info/description', data = "Contain all SRRs from dee2 for which sample_geo_accession starts with gse. Star estimates was groupped by sample_geo_accession and summed up")
        h5_gse.create_dataset('/meta/genes/genes', data = np.array(gene_symbol, dtype = dt), chunks = (len(gene_ind),), compression = "gzip", compression_opts = 7)
        h5_gse.create_dataset('/meta/samples/geo_accession', data = np.array(list(gsm_ind.keys()), dtype = dt), chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/SRR_accession', data = np.reshape(np.array(meta_df["SRR_accession"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/instrument_model', data = np.reshape(np.array(meta_df["experiment_instrument_model"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/library_selection', data = np.reshape(np.array(meta_df["experiment_library_selection"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/library_source', data = np.reshape(np.array(meta_df["experiment_library_source"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/library_strategy', data = np.reshape(np.array(meta_df["experiment_library_strategy"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/organism_ch1', data = np.reshape(np.array(meta_df["sample_scientific_name"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/series_id', data = np.reshape(np.array(meta_df[geo_gse_name]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/quality', data = np.reshape(np.array(meta_df["QC_summary"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/title', data = np.reshape(np.array(meta_df["Experiment_title"]), -1), dtype = dt, chunks = (2000,), compression = "gzip", compression_opts = 5)
        h5_gse.create_dataset('/meta/samples/type', data = np.full(len(meta_df["Experiment_title"]), "SRA", dtype = dt), chunks = (2000,), compression = "gzip", compression_opts = 5)

        h5_gse.flush()
    print("done")  # was a leftover debug print(1)

    

# Snakemake entry point: file paths come from the rule's input/output declarations.
generate_dee2_h5(
    data_file_name = snakemake.input.data_file_name,
    metadata_file = snakemake.input.meta_file,
    gene_info_file_name = snakemake.input.gene_info,
    h5_gsm_name = snakemake.output.dee2h5,
)