/* -*- Mode: C; c-basic-offset:4 ; -*- */
/* Copyright (c) 2001-2013, The Ohio State University. All rights
 * reserved.
 *
 * This file is part of the MVAPICH2 software package developed by the
 * team members of The Ohio State University's Network-Based Computing
 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
 *
 * For detailed copyright and licensing information, please refer to the
 * copyright file COPYRIGHT in the top level MVAPICH2 directory.
 */
/*
 *
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

#ifdef _ENABLE_CUDA_
#if defined(_OSU_MVAPICH_) || defined(_OSU_PSM_)
#include "datatype.h"
#include "coll_shmem.h"
#include "unistd.h"

extern void *mv2_cuda_allgather_store_buf;
extern int mv2_cuda_allgather_store_buf_size;
extern cudaEvent_t *mv2_cuda_sync_event;

#undef FUNCNAME
#define FUNCNAME MPIR_Allgather_cuda_intra_MV2
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
/*
 * MPIR_Allgather_cuda_intra_MV2 - allgather into a device (CUDA) receive
 * buffer, staged through a registered, page-aligned host store buffer.
 *
 * Block r of the result (recvcount elements of recvtype) is placed at
 * byte offset r * recvcount * extent(recvtype) of recvbuf.  Two algorithms
 * are used:
 *   - ring, when recvcount * extent exceeds
 *     rdma_cuda_allgather_rd_limit * comm_size, or when comm_size is not a
 *     power of two;
 *   - recursive doubling otherwise.
 * Intermediate blocks are received into mv2_cuda_allgather_store_buf and
 * copied to recvbuf with MPIU_Memcpy_CUDA_Async (stream argument 0); the
 * last step of each algorithm receives straight into recvbuf.  Before
 * returning, an event recorded on stream 0 (mv2_cuda_sync_event) is
 * synchronized so all queued copies have completed.
 *
 * The store buffer and the event are process-wide globals reused across
 * calls; the store buffer is reallocated whenever it is too small.
 *
 * Parameters follow MPI_Allgather.  errflag follows the MPICH collective
 * convention: communication errors set *errflag, are accumulated in
 * mpi_errno_ret and the algorithm keeps going; the combined error is
 * returned at the end.
 *
 * NOTE(review): blocks are addressed as recvcount * extent bytes, which
 * presumes recvtype is contiguous — presumably guaranteed by the caller;
 * confirm.
 */
int MPIR_Allgather_cuda_intra_MV2(const void *sendbuf,
                             int sendcount,
                             MPI_Datatype sendtype,
                             void *recvbuf,
                             int recvcount,
                             MPI_Datatype recvtype,
                             MPID_Comm * comm_ptr, int *errflag)
{
    int comm_size, rank;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    MPI_Aint recvtype_extent = 0;
    int j, i;
    int curr_cnt, dst, left, right, jnext;
    MPI_Comm comm;
    int mask, dst_tree_root, my_tree_root;
    MPI_Aint send_offset, recv_offset;
    int comm_size_is_pof2;
    MPI_Status status;
    int page_size = 0;
    int result, max_size;
    MPI_Request recv_req;
    MPI_Request send_req;
    cudaError_t cudaerr;

    if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0)) {
        return MPI_SUCCESS;
    }

    comm = comm_ptr->handle;
    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;
    comm_size_is_pof2 = comm_ptr->ch.is_pof2;

    MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);

    /* check if multiple threads are calling this collective function */
    MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER(comm_ptr);

    /* Create (or grow) the host store buffer so it can hold the whole
     * gathered result. */
    page_size = getpagesize();

    /* NOTE(review): the required size is an MPI_Aint but the global size is
     * an int, so results larger than INT_MAX bytes would be truncated. */
    max_size = mv2_cuda_allgather_store_buf_size < recvcount * comm_size * recvtype_extent ?
            recvcount * comm_size * recvtype_extent : mv2_cuda_allgather_store_buf_size;

    if (mv2_cuda_allgather_store_buf_size < max_size || !mv2_cuda_allgather_store_buf) {
        if (mv2_cuda_allgather_store_buf) {
            ibv_cuda_unregister(mv2_cuda_allgather_store_buf);
            free(mv2_cuda_allgather_store_buf);
            /* Forget the freed buffer right away: if the allocation below
             * fails, a later call must not unregister/free it again. */
            mv2_cuda_allgather_store_buf = NULL;
            mv2_cuda_allgather_store_buf_size = 0;
        }
        result = posix_memalign(&mv2_cuda_allgather_store_buf, page_size, max_size);
        if ((result != 0) || (NULL == mv2_cuda_allgather_store_buf)) {
            mv2_cuda_allgather_store_buf = NULL;
            /* posix_memalign reports failure through its return value;
             * it does not set errno. */
            mpi_errno = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE,
                    FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", "%s: %s",
                    "posix_memalign", strerror(result));
            MPIU_ERR_POP(mpi_errno);
        }
        ibv_cuda_register(mv2_cuda_allgather_store_buf, max_size);
        mv2_cuda_allgather_store_buf_size = max_size;
    }

    /* Create the event used to wait for the asynchronous copies at the end.
     * It is only published once fully initialised, so a failure here does
     * not leave a half-built event behind for later calls. */
    if (!mv2_cuda_sync_event) {
        cudaEvent_t *event = (cudaEvent_t *) MPIU_Malloc(sizeof(cudaEvent_t));
        if (NULL == event) {
            mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME,
                    __LINE__, MPI_ERR_OTHER, "**nomem", 0);
            goto fn_fail;
        }
        cudaerr = cudaEventCreateWithFlags(event, cudaEventDisableTiming);
        if (cudaerr != cudaSuccess) {
            MPIU_Free(event);
            mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME,
                    __LINE__, MPI_ERR_OTHER, "**nomem", 0);
            goto fn_fail;
        }
        mv2_cuda_sync_event = event;
    }

    if (recvcount*recvtype_extent > rdma_cuda_allgather_rd_limit*comm_size ||
            !comm_size_is_pof2) { /* RING */

        /* First, load the "local" version in the recvbuf. */
        if (sendbuf != MPI_IN_PLACE) {
            mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype,
                                       ((char *) recvbuf +
                                        rank * recvcount * recvtype_extent),
                                       recvcount, recvtype);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
        }

        /* Each step forwards to the right neighbour the block received from
         * the left neighbour in the previous step. */
        left  = (comm_size + rank - 1) % comm_size;
        right = (rank + 1) % comm_size;
        j     = rank;
        jnext = left;

        /* First step: send our own block straight from recvbuf, receive the
         * left neighbour's block into the store buffer. */
        mpi_errno = MPIC_Irecv(((char *)mv2_cuda_allgather_store_buf + jnext*recvcount*recvtype_extent),
                               recvcount*recvtype_extent,
                               MPI_BYTE,
                               left,
                               MPIR_ALLGATHER_TAG,
                               comm,
                               &recv_req);
        if (mpi_errno) {
            MPIU_ERR_POP(mpi_errno);
        }
        mpi_errno = MPIC_Isend(((char *)recvbuf + j*recvcount*recvtype_extent),
                               recvcount*recvtype_extent,
                               MPI_BYTE,
                               right,
                               MPIR_ALLGATHER_TAG,
                               comm,
                               &send_req);
        if (mpi_errno) {
            MPIU_ERR_POP(mpi_errno);
        }
        mpi_errno = MPIC_Waitall_ft(1, &recv_req, &status, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag = TRUE;
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }

        MPIU_Memcpy_CUDA_Async((void *)((char *)recvbuf + jnext*recvcount*recvtype_extent),
                (void *)((char *)mv2_cuda_allgather_store_buf + jnext*recvcount*recvtype_extent),
                recvcount*recvtype_extent,
                cudaMemcpyHostToDevice,
                0);

        mpi_errno = MPIC_Waitall_ft(1, &send_req, &status, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag = TRUE;
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }

        j     = jnext;
        jnext = (comm_size + jnext - 1) % comm_size;

        /* Intermediate steps: forward from, and receive into, the store
         * buffer; queue an asynchronous copy of each received block. */
        for (i = 2; i < comm_size - 1; i++) {
            mpi_errno = MPIC_Irecv(((char *)mv2_cuda_allgather_store_buf + jnext*recvcount*recvtype_extent),
                                   recvcount,
                                   recvtype,
                                   left,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &recv_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
            mpi_errno = MPIC_Isend(((char *)mv2_cuda_allgather_store_buf + j*recvcount*recvtype_extent),
                                   recvcount,
                                   recvtype,
                                   right,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &send_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
            mpi_errno = MPIC_Waitall_ft(1, &recv_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            MPIU_Memcpy_CUDA_Async((void *)((char *)recvbuf + jnext*recvcount*recvtype_extent),
                    (void *)((char *)mv2_cuda_allgather_store_buf + jnext*recvcount*recvtype_extent),
                    recvcount*recvtype_extent,
                    cudaMemcpyHostToDevice,
                    0);

            mpi_errno = MPIC_Waitall_ft(1, &send_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            j     = jnext;
            jnext = (comm_size + jnext - 1) % comm_size;
        }

        /* Last stage of communication - receive directly into the device
         * buffer (only reached when comm_size >= 3). */
        if (i < comm_size) {
            mpi_errno = MPIC_Irecv(((char *)recvbuf + jnext*recvcount*recvtype_extent),
                                   recvcount,
                                   recvtype,
                                   left,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &recv_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
            mpi_errno = MPIC_Isend(((char *)mv2_cuda_allgather_store_buf + j*recvcount*recvtype_extent),
                                   recvcount,
                                   recvtype,
                                   right,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &send_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
            mpi_errno = MPIC_Waitall_ft(1, &recv_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
            mpi_errno = MPIC_Waitall_ft(1, &send_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }

    } else { /* Recursive Doubling */

        if (sendbuf != MPI_IN_PLACE) {
            mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype,
                                       ((char *) recvbuf +
                                        rank * recvcount * recvtype_extent),
                                       recvcount, recvtype);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
        }

        curr_cnt = recvcount;

        mask = 0x1;
        i = 0;

        /* First step (mask == 1), peeled off so our own block can be staged
         * into the store buffer: every later send is taken from there. */
        dst = rank ^ mask;
        dst_tree_root = dst >> i;
        dst_tree_root <<= i;

        my_tree_root = rank >> i;
        my_tree_root <<= i;

        send_offset = (MPI_Aint) my_tree_root * recvcount * recvtype_extent;
        recv_offset = (MPI_Aint) dst_tree_root * recvcount * recvtype_extent;

        if (dst < comm_size) {
            MPIU_Memcpy_CUDA((void*)((char *)mv2_cuda_allgather_store_buf + rank*recvcount*recvtype_extent),
                             (void*)((char *)recvbuf + rank*recvcount*recvtype_extent),
                             recvcount * recvtype_extent,
                             cudaMemcpyDeviceToHost);

            mpi_errno = MPIC_Irecv(((char *)mv2_cuda_allgather_store_buf + recv_offset),
                                   (mask)*recvcount,
                                   recvtype,
                                   dst,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &recv_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }
            mpi_errno = MPIC_Isend(((char *)mv2_cuda_allgather_store_buf + send_offset),
                                   curr_cnt,
                                   recvtype,
                                   dst,
                                   MPIR_ALLGATHER_TAG,
                                   comm,
                                   &send_req);
            if (mpi_errno) {
                MPIU_ERR_POP(mpi_errno);
            }

            mpi_errno = MPIC_Waitall_ft(1, &recv_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            MPIU_Memcpy_CUDA_Async((void*)((char *)recvbuf + recv_offset),
                                   (void*)((char *)mv2_cuda_allgather_store_buf + recv_offset),
                                   (mask)*recvcount*recvtype_extent,
                                   cudaMemcpyHostToDevice,
                                   0);

            mpi_errno = MPIC_Waitall_ft(1, &send_req, &status, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            curr_cnt += mask*recvcount;
        }

        mask <<= 1;
        i++;

        while (mask < comm_size) {
            dst = rank ^ mask;

            /* find offset into send and recv buffers. zero out
               the least significant "i" bits of rank and dst to
               find root of src and dst subtrees. Use ranks of
               roots as index to send from and recv into buffer */

            dst_tree_root = dst >> i;
            dst_tree_root <<= i;

            my_tree_root = rank >> i;
            my_tree_root <<= i;

            send_offset = (MPI_Aint) my_tree_root * recvcount * recvtype_extent;
            recv_offset = (MPI_Aint) dst_tree_root * recvcount * recvtype_extent;

            if (dst < comm_size) {
                /* The final step receives straight into the device buffer;
                 * earlier steps go through the host store buffer, which is
                 * also needed as the send source of later steps. */
                if (mask == comm_size/2) {
                    mpi_errno = MPIC_Irecv(((char *)recvbuf + recv_offset),
                                           (mask)*recvcount,
                                           recvtype,
                                           dst,
                                           MPIR_ALLGATHER_TAG,
                                           comm,
                                           &recv_req);
                } else {
                    mpi_errno = MPIC_Irecv(((char *)mv2_cuda_allgather_store_buf + recv_offset),
                                           (mask)*recvcount,
                                           recvtype,
                                           dst,
                                           MPIR_ALLGATHER_TAG,
                                           comm,
                                           &recv_req);
                }
                if (mpi_errno) {
                    MPIU_ERR_POP(mpi_errno);
                }
                mpi_errno = MPIC_Isend(((char *)mv2_cuda_allgather_store_buf + send_offset),
                                       curr_cnt,
                                       recvtype,
                                       dst,
                                       MPIR_ALLGATHER_TAG,
                                       comm,
                                       &send_req);
                if (mpi_errno) {
                    MPIU_ERR_POP(mpi_errno);
                }
                mpi_errno = MPIC_Waitall_ft(1, &recv_req, &status, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    *errflag = TRUE;
                    MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                    MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                }

                if (mask < comm_size/2) {
                    MPIU_Memcpy_CUDA_Async(((void*) ((char *)recvbuf + recv_offset)),
                                           (void *)((char *)mv2_cuda_allgather_store_buf + recv_offset),
                                           (mask)*recvcount*recvtype_extent,
                                           cudaMemcpyHostToDevice,
                                           0);
                }
                mpi_errno = MPIC_Waitall_ft(1, &send_req, &status, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    *errflag = TRUE;
                    MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                    MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                }

                curr_cnt += mask*recvcount;
            }

            mask <<= 1;
            i++;
        }
    }

    /* wait for the receive copies into the device to complete */
    cudaerr = cudaEventRecord(*mv2_cuda_sync_event, 0);
    if (cudaerr != cudaSuccess) {
        mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME,
                __LINE__, MPI_ERR_OTHER, "**cudaEventRecord", 0);
        goto fn_fail;
    }
    cudaerr = cudaEventSynchronize(*mv2_cuda_sync_event);
    if (cudaerr != cudaSuccess) {
        mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME,
                __LINE__, MPI_ERR_OTHER, "**fail", "%s: %s",
                "cudaEventSynchronize", cudaGetErrorString(cudaerr));
        goto fn_fail;
    }

  fn_exit:
    /* check if multiple threads are calling this collective function;
     * done on every path after ENTER, including the error paths */
    MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT(comm_ptr);
    /* report communication errors recorded while the algorithm continued */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag)
        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**coll_fail");
    return (mpi_errno);

  fn_fail:
    goto fn_exit;
}
/* end:nested */
#endif /* #if defined(_OSU_MVAPICH_) || defined(_OSU_PSM_) */
#endif /*#ifdef(_ENABLE_CUDA_)*/
