# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import tempfile

import mmcv.fileio
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def broadcast_object_list(object_list, src=0):
    """Broadcasts picklable objects in ``object_list`` to the whole group.

    The objects are pickled into a temporary directory on the ``src`` rank,
    and the directory path is broadcast as a fixed-size uint8 tensor so every
    other rank can load the pickle from the shared filesystem. This therefore
    requires a filesystem visible to all ranks and an initialized (CUDA)
    process group.

    Note that all objects in ``object_list`` must be picklable in order to be
    broadcasted.

    Args:
        object_list (List[Any]): List of input objects to broadcast.
            Each object must be picklable. Only objects on the src rank will be
            broadcast, but each rank must provide lists of equal sizes.
        src (int): Source rank from which to broadcast ``object_list``.

    Returns:
        List[Any]: The broadcast objects, identical on every rank.

    Raises:
        ValueError: If the temporary directory path is longer than the
            fixed broadcast buffer (512 bytes).
    """
    my_rank, _ = get_dist_info()

    MAX_LEN = 512
    # 32 is the ASCII code for whitespace; unused tail bytes stay as spaces
    # so the receiving ranks can strip the padding after decoding.
    dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8, device='cuda')
    object_list_return = list()
    if my_rank == src:
        mmcv.mkdir_or_exist('.dist_broadcast')
        tmpdir = tempfile.mkdtemp(dir='.dist_broadcast')
        mmcv.dump(object_list, osp.join(tmpdir, 'object_list.pkl'))
        path_tensor = torch.tensor(
            bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
        # Fail with a clear message instead of a cryptic slice-assignment
        # error when the path does not fit into the fixed-size buffer.
        if len(path_tensor) > MAX_LEN:
            raise ValueError(
                f'Temporary directory path {tmpdir!r} exceeds the '
                f'{MAX_LEN}-byte broadcast buffer.')
        dir_tensor[:len(path_tensor)] = path_tensor

    dist.broadcast(dir_tensor, src)
    # Decode the path on every rank and drop the whitespace padding.
    tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()

    if my_rank != src:
        object_list_return = mmcv.load(osp.join(tmpdir, 'object_list.pkl'))

    # Ensure every rank has finished loading before src deletes the files.
    dist.barrier()
    if my_rank == src:
        shutil.rmtree(tmpdir)
        object_list_return = object_list

    return object_list_return