赞
踩
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
cargo build # 测试,或执行 cargo check
use candle_core::{Device, Tensor};
/// Multiplies two randomly initialised matrices on the CPU and prints the product.
///
/// A `(2, 3)` and a `(3, 4)` `f32` tensor are created with
/// `Tensor::randn(0f32, 1., ..)`, multiplied with `matmul`, and the result is
/// printed through its `Display` implementation.
///
/// # Errors
///
/// Any error raised while creating the tensors or multiplying them is
/// propagated with `?`, boxed as `dyn std::error::Error`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Everything in this example runs on the CPU backend.
    let cpu = Device::Cpu;

    // Operand shapes are chosen so that (2, 3) x (3, 4) is a valid product.
    let lhs = Tensor::randn(0f32, 1., (2, 3), &cpu)?;
    let rhs = Tensor::randn(0f32, 1., (3, 4), &cpu)?;

    let c = lhs.matmul(&rhs)?;
    println!("{c}");

    Ok(())
}
~/myrust$ cargo new myapp Created binary (application) `myapp` package ~/myrust$ cd myapp ~/myrust/myapp$ cargo add --git https://github.com/huggingface/candle.git candle-core Updating git repository `https://github.com/huggingface/candle.git` Updating git submodule `https://github.com/NVIDIA/cutlass.git` Adding candle-core (git) to dependencies. Features: - accelerate - cuda - cudarc - cudnn - metal - mkl Updating git repository `https://github.com/huggingface/candle.git` Updating crates.io index ~/myrust/myapp$ cargo build Downloaded serde_derive v1.0.195 Downloaded either v1.9.0 Downloaded autocfg v1.1.0 Downloaded zerofrom v0.1.3 Downloaded zerofrom-derive v0.1.3 Downloaded synstructure v0.13.0 Downloaded crossbeam-deque v0.8.5 Downloaded yoke-derive v0.7.3 Downloaded half v2.3.1 Downloaded bytemuck v1.14.1 Downloaded rand_core v0.6.4 Downloaded paste v1.0.14 Downloaded proc-macro2 v1.0.78 Downloaded itoa v1.0.10 Downloaded memmap2 v0.9.4 Downloaded syn v2.0.48 Downloaded crossbeam-epoch v0.9.18 Downloaded cfg-if v1.0.0 Downloaded bitflags v1.3.2 Downloaded num_cpus v1.16.0 Downloaded gemm-f32 v0.17.0 Downloaded reborrow v0.5.5 Downloaded stable_deref_trait v1.2.0 Downloaded rayon-core v1.12.1 Downloaded seq-macro v0.3.5 Downloaded thiserror-impl v1.0.56 Downloaded dyn-stack v0.10.0 Downloaded thiserror v1.0.56 Downloaded unicode-xid v0.2.4 Downloaded rand_chacha v0.3.1 Downloaded ppv-lite86 v0.2.17 Downloaded bytemuck_derive v1.5.0 Downloaded getrandom v0.2.12 Downloaded once_cell v1.19.0 Downloaded unicode-ident v1.0.12 Downloaded byteorder v1.5.0 Downloaded crc32fast v1.3.2 Downloaded num-complex v0.4.4 Downloaded gemm-common v0.17.0 Downloaded crossbeam-utils v0.8.19 Downloaded quote v1.0.35 Downloaded ryu v1.0.16 Downloaded num-traits v0.2.17 Downloaded zip v0.6.6 Downloaded rand_distr v0.4.3 Downloaded serde v1.0.195 Downloaded rand v0.8.5 Downloaded raw-cpuid v10.7.0 Downloaded libm v0.2.8 Downloaded serde_json v1.0.111 Downloaded rayon v1.8.1 Downloaded 
libc v0.2.152 Downloaded gemm-c64 v0.17.0 Downloaded gemm-c32 v0.17.0 Downloaded safetensors v0.4.2 Downloaded gemm-f64 v0.17.0 Downloaded gemm v0.17.0 Downloaded gemm-f16 v0.17.0 Downloaded yoke v0.7.3 Downloaded pulp v0.18.6 Downloaded 60 crates (3.1 MB) in 14.91s Compiling proc-macro2 v1.0.78 Compiling unicode-ident v1.0.12 Compiling libc v0.2.152 Compiling cfg-if v1.0.0 Compiling libm v0.2.8 Compiling autocfg v1.1.0 Compiling crossbeam-utils v0.8.19 Compiling ppv-lite86 v0.2.17 Compiling rayon-core v1.12.1 Compiling reborrow v0.5.5 Compiling paste v1.0.14 Compiling either v1.9.0 Compiling bitflags v1.3.2 Compiling seq-macro v0.3.5 Compiling once_cell v1.19.0 Compiling unicode-xid v0.2.4 Compiling raw-cpuid v10.7.0 Compiling serde v1.0.195 Compiling crc32fast v1.3.2 Compiling serde_json v1.0.111 Compiling stable_deref_trait v1.2.0 Compiling itoa v1.0.10 Compiling ryu v1.0.16 Compiling thiserror v1.0.56 Compiling byteorder v1.5.0 Compiling num-traits v0.2.17 Compiling zip v0.6.6 Compiling crossbeam-epoch v0.9.18 Compiling quote v1.0.35 Compiling syn v2.0.48 Compiling crossbeam-deque v0.8.5 Compiling getrandom v0.2.12 Compiling memmap2 v0.9.4 Compiling num_cpus v1.16.0 Compiling rand_core v0.6.4 Compiling rand_chacha v0.3.1 Compiling rayon v1.8.1 Compiling rand v0.8.5 Compiling rand_distr v0.4.3 Compiling synstructure v0.13.0 Compiling bytemuck_derive v1.5.0 Compiling serde_derive v1.0.195 Compiling zerofrom-derive v0.1.3 Compiling thiserror-impl v1.0.56 Compiling yoke-derive v0.7.3 Compiling bytemuck v1.14.1 Compiling num-complex v0.4.4 Compiling dyn-stack v0.10.0 Compiling half v2.3.1 Compiling zerofrom v0.1.3 Compiling yoke v0.7.3 Compiling pulp v0.18.6 Compiling gemm-common v0.17.0 Compiling gemm-f32 v0.17.0 Compiling gemm-c64 v0.17.0 Compiling gemm-f64 v0.17.0 Compiling gemm-c32 v0.17.0 Compiling gemm-f16 v0.17.0 Compiling gemm v0.17.0 Compiling safetensors v0.4.2 Compiling candle-core v0.3.3 (https://github.com/huggingface/candle.git#fd7c8565) Compiling 
myapp v0.1.0 (/home/pdd/myrust/myapp) Finished dev [unoptimized + debuginfo] target(s) in 32.90s
git clone https://github.com/RileySeaburg/candle_test.git
Cargo.toml 文件:
[package]
name = "candle_test"
version = "0.1.0"
edition = "2021" # Rust 版次(edition),并非编译器版本号
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.2.1", features = ["cuda"] }
# `candle-core`:项目依赖的包的名称。`git` 字段指定了包的源代码仓库地址。`version` 字段指定了使用的包的版本。`features` 字段是一个数组,指定了启用的功能。在这里,启用了 "cuda" 功能。
# 也可以通过以下命令添加依赖;如果不需要 CUDA,可去掉 features 中的 "cuda",然后再执行 cargo build
# cargo add --git https://github.com/huggingface/candle.git candle-core
# cargo add candle-core --features cuda
use candle_core::{DType, Device, Result, Tensor};

/// A minimal two-layer model: two weight matrices with a ReLU in between.
struct Model {
    /// First-layer weights, right-multiplied onto the input.
    first: Tensor,
    /// Second-layer weights, right-multiplied onto the activated hidden values.
    second: Tensor,
}

impl Model {
    /// Runs the forward pass: `relu(image x first) x second`.
    ///
    /// # Errors
    ///
    /// Returns the error from either `matmul` call or from `relu` as soon as
    /// one of them fails; otherwise returns the resulting tensor.
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        // Multiply the input by the first-layer weights.
        let x = image.matmul(&self.first)?;
        // Apply the ReLU activation.
        let x = x.relu()?;
        // Multiply by the second-layer weights; this `Result` is the return value.
        x.matmul(&self.second)
    }
}

/// Builds an all-zero model and runs one forward pass on an all-zero input.
///
/// # Errors
///
/// Propagates any tensor-creation or forward-pass error instead of panicking.
fn main() -> Result<()> {
    // Use CUDA device 0 when it can be created, otherwise fall back to the CPU.
    let device = Device::new_cuda(0).unwrap_or(Device::Cpu);

    // Weight tensors for the first and second layer.
    let first = Tensor::zeros((784, 100), DType::F32, &device)?.contiguous()?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)?.contiguous()?;

    let model = Model { first, second };

    // A dummy input whose width matches the first layer's 784 rows.
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?.contiguous()?;

    // Run the forward pass and print the result.
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");

    Ok(())
}

// `Result` is defined in /home/pdd/.cargo/git/checkouts/candle-0c2b4fa9e5801351/e8e3375/candle-core/src/error.rs
pub type Result<T> = std::result::Result<T, Error>; // Declares `Result<T>` as an alias for `std::result::Result<T, Error>`: `T` is the type returned on success and `Error` is the type returned on failure.
// `Ok(())` is defined in /home/pdd/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs
// This is the public `Result` enum from the Rust standard library. It has two generic parameters, `T` and `E`: `T` is the type of the value returned on success, and `E` is the type of the error returned on failure.
// `#[...]` is an attribute; it attaches extra information to the item that follows.
pub enum Result<T, E> {
/// Contains the success value
#[lang = "Ok"]
#[stable(feature = "rust1", since = "1.0.0")]
Ok(#[stable(feature = "rust1", since = "1.0.0")] T),// `Ok(T)`: the variant of `Result` that represents success.
// `()` is Rust's unit type, comparable to `void` in other languages; `Ok(())` therefore means "succeeded, with no meaningful value".
/// Contains the error value
#[lang = "Err"]
#[stable(feature = "rust1", since = "1.0.0")]
Err(#[stable(feature = "rust1", since = "1.0.0")] E),// `Err(E)`: the other variant of `Result`, representing failure.
}
? 符号用于处理 Result 或 Option 类型的返回值。当表达式的结果是 Err 或 None 时,? 会让当前函数立即返回,把该错误(或 None)传播给调用者;否则取出其中的值继续执行。这样可以省去大量 match 样板代码,使代码更加简洁和易读。例如:
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = image.matmul(&self.first)?; // 如果matmul返回Err,则整个forward函数返回Err
let x = x.relu()?; // 如果relu返回Err,则整个forward函数返回Err
x.matmul(&self.second) // 如果matmul返回Err,则整个forward函数返回Err;否则返回Ok(Tensor)
}
函数体:函数体是一个块表达式,其值是最后一个表达式的值。
/// Returns the sum of `x` and `y`.
///
/// The function body is a block expression, so its value is the value of its
/// final expression; no `return` keyword or trailing semicolon is used.
fn add(x: i32, y: i32) -> i32 {
    // Tail expression: this value is what the function returns.
    x + y
}
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。