欢迎光临散文网 会员登录 & 注册

yolov5修改dml运行代码

2023-09-14 14:46 作者:pzqking  | 我要投稿

def select_device(device='', batch_size=0, newline=True):
    """Pick the torch device to run on and log a one-line environment summary.

    YOLOv5's ``select_device`` extended with a DirectML ("dml") option backed by
    the optional ``torch_directml`` package.

    Args:
        device: Device request. Accepted forms: '' / None (auto: CUDA if available,
            else CPU), 'cpu', 'mps', 'dml' (DirectML adapter 0), 'dml:N' (DirectML
            adapter N), 0, '0', 'cuda:0' or a comma list such as '0,1,2,3'.
        batch_size: When > 0 and more than one CUDA device is selected, it must be
            a multiple of the CUDA device count.
        newline: If False, trailing whitespace/newline is stripped from the logged
            summary line.

    Returns:
        torch.device: The selected device. If DirectML is requested but
        ``torch_directml`` is missing or reports no available device, the CPU
        device is returned and a warning is logged.

    Raises:
        AssertionError: If requested CUDA devices are unavailable, if the batch size
            is not divisible by the CUDA device count, or if the 'dml:N' index is
            not a non-negative integer.
    """
    # device = None or 'cpu' or 'mps' or 'dml' or 'dml:0' or 0 or '0' or '0,1,2,3'
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    dml = device == 'dml' or device.startswith('dml:')  # DirectML, optionally with an adapter index

    dml_index = 0  # 'dml' alone selects DirectML adapter 0
    if device.startswith('dml:'):
        index = device[len('dml:'):]
        assert index.isdecimal(), \
            f"Invalid DirectML '--device {device}' requested, use e.g. '--device dml' or '--device dml:0'"
        dml_index = int(index)

    if cpu or mps or dml:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    dml_module = None
    if dml:
        # Imported lazily: torch-directml is an optional dependency and the hosting
        # module (utils/torch_utils.py) does not import it at top level.
        try:
            import torch_directml as dml_module
        except ImportError:
            dml_module = None  # handled below by falling back to CPU with a warning

    if dml_module is not None and dml_module.is_available():
        s += f'dml:{dml_module.device_name(dml_index)}\n'
        arg = dml_module.device(dml_index)
    elif not (cpu or mps or dml) and torch.cuda.is_available():  # prefer GPU if available
        # 'dml' is excluded so a failed DirectML request can never silently land on CUDA.
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        if dml:
            LOGGER.warning('WARNING ⚠️ DirectML requested but torch_directml is not installed or not available, '
                           'falling back to CPU')
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)

# 用此函数替换 utils 文件夹下 torch_utils.py 中的 select_device 函数即可使用；使用 DirectML 需先安装 torch-directml 包，并在运行时传入 --device dml。

yolov5修改dml运行代码的评论 (共 条)

分享到微博请遵守国家法律