diff --git a/main.py b/main.py index c2321086..05eb31c7 100644 --- a/main.py +++ b/main.py @@ -71,6 +71,7 @@ if os.name == "nt": if __name__ == "__main__": if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.deterministic: