赞
踩
Windows 11已经支持使用directml加速 pytorch了。
2024.5.24 更新,新的torch-directml 包已经发布,完美支持常用算法。
-----------------------------------------------------------------------------------------------------
2021,11,16更新: directml-pytorch已经推出:
详细教程:(4条消息) Windows下用amd显卡训练 : Pytorch-directml 重大升级,改为pytorch插件形式,兼容更好_znsoft的博客-CSDN博客_amd显卡 pytorch
-----------------------------------------------------------------------------
以 下为旧内容,依然适用,但是不推荐了。看横线以上的。
官方训练原理解释: ONNX Runtime Training Technical Deep Dive - Microsoft Tech Community
检查 支持的设备
- import onnxruntime as ort
- ort.get_device()
ONNX运行时(ORT)能够通过优化的后端训练现有的PyTorch模型。为此,我们为pythorch引入了一个pythorch API,称为ORTTrainer,可用于将pythorch模型的训练后端(实例torch.nn.Module
)切换到orttrainer
。这需要对trainer代码进行一些更改,比如替换PyTorch优化器,还可以选择设置标志来启用其他特性,比如mixed-precisiontraining。下面是一个将ONNX运行时培训集成到PyTorchpre-training脚本中的示例代码片段:
注:目前的API是实验性的,预计在不久的将来会有重大变化。我们的目标是改进接口,以提供与Pythorch训练的无缝集成,这需要对用户的训练代码进行最小的更改。
-
-
- import torch
- ...
- import onnxruntime
- from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer
-
- # Model definition
- class Net(torch.nn.Module):
- def __init__(self, D_in, H, D_out):
- ...
- def forward(self, x):
- ...
- model = Net(D_in, H, H_out)
- criterion = torch.nn.Functional.cross_entropy
- description = ModelDescription(...)
- optimizer = 'SGDOptimizer'
- trainer = ORTTrainer(model, criterion, description, optimizer, ...)
- # Training Loop
- for t in range(1000):
- # forward + backward + weight update
- loss, y_pred = trainer.train_step(x, y, learning_rate)
- ...

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。