diff --git a/data analysis/cudaregress.py b/data analysis/cudaregress.py index 800c8515..8cdbe141 100644 --- a/data analysis/cudaregress.py +++ b/data analysis/cudaregress.py @@ -38,6 +38,8 @@ __all__ = [ import torch +#set device +device='cuda:0' if torch.cuda.is_available() else 'cpu' #todo: document completely