LLMs — Grok-1: A Walkthrough of checkpoint.py — Utilities for Loading and Restoring Machine-Learning Model Checkpoints (JAX-based multi-dimensional array computation + large-scale distributed training + cross-host data synchronization and sharding)
A Walkthrough of checkpoint.py — Utilities for Loading and Restoring Machine-Learning Model Checkpoints (JAX-based multi-dimensional array computation + large-scale distributed training + cross-host data synchronization and sharding)
Source: grok-1/checkpoint.py at main · xai-org/grok-1 · GitHub
checkpoint.py provides utilities for loading and restoring machine-learning model checkpoints. It uses shared memory and multithreading to speed up file operations, relies on the JAX library for multi-dimensional array computation, and supports data synchronization and sharding across multiple hosts, which makes it usable for large-scale distributed training.
Key points:
>> Uses shared memory (/dev/shm) to speed up file reads and writes.
>> Uses a thread pool to load multiple tensors in parallel.
>> Uses the JAX library for multi-dimensional array computation.
>> Supports data synchronization and sharding across multiple hosts.
load_tensors: loads a set of tensor arrays. The function takes a list of shaped arrays and a directory and loads the tensors asynchronously, using a thread pool to parallelize the reads. The tensor data is stored across multiple files, one file per host shard of each tensor, named tensor{i:05d}_{idx:03d}.
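A minimal usage sketch, assuming a single host and a hypothetical checkpoint directory (with mesh_config=(1, 1), jax.process_index() is 0, so the files read are tensor00000_000, tensor00001_000, ...):

import jax
import numpy as np
from checkpoint import load_tensors

# Hypothetical shape descriptors for two checkpointed tensors.
shaped_arrays = [
    jax.ShapeDtypeStruct((1024, 4096), np.dtype("float32")),
    jax.ShapeDtypeStruct((4096,), np.dtype("float32")),
]
# Unpickles checkpoints/ckpt-0/tensor00000_000 and .../tensor00001_000 in parallel.
tensors = load_tensors(shaped_arrays, "checkpoints/ckpt-0", mesh_config=(1, 1))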
fast_unpickle and fast_pickle: fast loading and saving of Python objects. They read (unpickle) and write (pickle) objects through the copy_to_shm and copy_from_shm context managers described below, which stage the files in shared memory to speed up the I/O. See the sketch after these two items.
fast_unpickle: fast deserialization (loading a pickled object)
fast_pickle: fast serialization (saving an object with pickle)
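A round-trip sketch (the path is arbitrary):

from checkpoint import fast_pickle, fast_unpickle

fast_pickle({"step": 0, "params": {}}, "/tmp/state.pkl")  # staged in /dev/shm, then copied out
obj = fast_unpickle("/tmp/state.pkl")                     # copied into /dev/shm, then loaded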
restore: restores state from a checkpoint. This is the main restore entry point: it takes the checkpoint path, the state shapes, the mesh configuration, and other parameters, and returns the restored model state (or only the parameters when params_only is set).
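A hedged call sketch: state_shapes, mesh, and state_sharding below are placeholders that a driver script (such as the repo's run.py) builds from the model configuration; note that restore itself appends "ckpt-0" to checkpoint_path:

from checkpoint import restore

params = restore(
    checkpoint_path="./checkpoints",   # actually loads from ./checkpoints/ckpt-0
    state_shapes=state_shapes,         # placeholder: pytree of jax.ShapeDtypeStruct leaves
    mesh=mesh,                         # placeholder: jax.sharding.Mesh over all devices
    between_hosts_config=(1, 1),       # data/model sharding between hosts
    params_only=True,                  # return state.params instead of the full state
    state_sharding=state_sharding,     # placeholder: pytree of jax.sharding.PartitionSpec
)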
replace_with_load_state: replaces tensors in an initialized state with the corresponding tensors from a loaded state. It converts each path in the initial state to a string, applies the path rules via get_load_path_str, and substitutes the matching tensor from the loaded state; excluded or missing paths keep the initial tensor (or are zero-filled on non-owning hosts).
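A tiny self-contained sketch (the pytrees are made up) showing the substitution behavior:

import numpy as np
from checkpoint import replace_with_load_state

init_state = {"params": {"w": np.zeros((2, 2)), "b": np.zeros(2)}}
load_state = {"params": {"w": np.full((2, 2), 7.0), "b": np.ones(2)}}

restored = replace_with_load_state(init_state, load_state)
# restored["params"]["w"] and ["b"] now come from load_state; passing
# load_exclude_rules=["b$"] instead would keep the initial "params/b".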
copy_to_shm and copy_from_shm: context managers that copy files between shared memory (/dev/shm/) and disk. They are used to speed up Python pickle loading and saving, since shared memory is typically much faster than disk. A usage sketch follows the two items below.
copy_to_shm: copy a file into shared memory (before reading)
copy_from_shm: copy a file out of shared memory (after writing)
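A usage sketch with hypothetical file paths (reads go through a copy in /dev/shm; writes are staged in /dev/shm and copied to the destination on exit):

from checkpoint import copy_to_shm, copy_from_shm

with copy_to_shm("ckpt-0/tensor00000_000") as shm_path:
    with open(shm_path, "rb") as f:    # reads hit shared memory, not disk
        blob = f.read()

with copy_from_shm("out/blob.bin") as shm_path:
    with open(shm_path, "wb") as f:    # written to shared memory first, ...
        f.write(blob)                  # ... copied to out/blob.bin on exit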
path_tuple_to_string: converts a path tuple into a string. JAX represents the location of a leaf in a pytree as a tuple whose elements each denote one level of the data structure (a dict key, an attribute name, etc.); this function joins those elements into a single '/'-separated string.
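For example, flattening a nested dict produces DictKey tuples, which the function joins with '/':

import jax
from checkpoint import path_tuple_to_string

flat, _ = jax.tree_util.tree_flatten_with_path({"params": {"w": 1.0}})
path, _leaf = flat[0]                  # (DictKey(key='params'), DictKey(key='w'))
print(path_tuple_to_string(path))      # -> "params/w"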
get_load_path_str: computes the load path string for a given initial path. It applies the naming rules for loading: exclusion patterns (a match returns None) and rename rules (the first matching regex rule is substituted).
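Two illustrative calls (the path strings and regex rules here are made up):

from checkpoint import get_load_path_str

# Renaming: the first matching rule is applied, then the search stops.
get_load_path_str("decoder/layer_0/w",
                  load_rename_rules=[(r"layer_(\d+)", r"block_\1")])
# -> "decoder/block_0/w"

# Exclusion: a matching exclude pattern returns None (skip this tensor).
get_load_path_str("opt_state/step", load_exclude_rules=[r"opt_state"])
# -> None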
multihost_utils.host_local_array_to_global_array: converts host-local arrays into global arrays laid out across the device mesh; restore uses it to turn the per-host tensors into globally sharded arrays for model parallelism.
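A single-process sketch of the call (the mesh axes and partition spec are illustrative; on one host the local and global arrays coincide):

import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))
local = np.zeros((8, 8), dtype=np.float32)
global_arr = multihost_utils.host_local_array_to_global_array(
    local, mesh, PartitionSpec("data", "model")
)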
QuantizedWeight8bit: an 8-bit quantized weight container, typically used to reduce the model's memory requirements. It is aliased into the __main__ module so that pickle can resolve it when loading the checkpoint.
logger and rank_logger: two loggers used to record information and warnings.
logger: the module-level logger
rank_logger: the per-rank logger (named "rank")
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional

import jax
import numpy as np
from jax.experimental import multihost_utils

from model import QuantizedWeight8bit

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

# Needed for loading the checkpoint with pickle.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit


@contextlib.contextmanager
def copy_to_shm(file: str):
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return

    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
        os.close(fd)


@contextlib.contextmanager
def copy_from_shm(file: str):
    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        yield tmp_path
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)
        os.close(fd)


def fast_unpickle(path: str) -> Any:
    with copy_to_shm(path) as tmp_path:
        with open(tmp_path, "rb") as f:
            return pickle.load(f)


def fast_pickle(obj: Any, path: str) -> None:
    with copy_from_shm(path) as tmp_path:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)


def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Loads a set of arrays."""
    pool = ThreadPoolExecutor(max_workers=32)
    fs = list()
    num_tensors = 0
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)
    for i, t in iterator:
        if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
            idx = (
                jax.process_index() // (num_replicas * data_model_shards) * data_model_shards
                + jax.process_index() % data_model_shards
            )
            fs.append(
                pool.submit(fast_unpickle, os.path.join(directory, f"tensor{i:05d}_{idx:03d}"))
            )
            num_tensors += 1
        else:
            fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
    wait(fs)
    return [f.result() for f in fs]


def path_tuple_to_string(path: tuple) -> str:
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            pieces.append(elem.key)
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        else:
            assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))
    return "/".join(pieces)


def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    # Exclusion
    if load_exclude_rules is not None:
        for search_pattern in load_exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None

    # Renaming
    load_path_str = init_path_str
    if load_rename_rules is not None:
        for search_pattern, replacement_pattern in load_rename_rules:
            if re.search(search_pattern, load_path_str):
                load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)
                break

    return load_path_str


def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
    flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
    flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}

    replaced = []
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    for i, (init_path, tensor) in enumerate(flatten_init):
        init_path_str = path_tuple_to_string(init_path)
        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
        if load_path_str is None:
            rank_logger.info(f"Excluded from restore: {init_path_str}.")
            replaced.append(tensor)
        elif load_path_str in load_map:
            if load_path_str == init_path_str:
                rank_logger.info(f"Restored from ckpt: {init_path_str}.")
            else:
                rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
            replaced.append(load_map[load_path_str])
        else:
            rank_logger.info(f"Not found in ckpt: {init_path_str}.")
            if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
                replaced.append(tensor)
            else:
                replaced.append(np.zeros_like(tensor))

    return jax.tree_util.tree_unflatten(structure_init, replaced)


def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")

    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
    ckpt_shapes = state_shapes
    ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)

    ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)

    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)

    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())

    if ckpt_keys != code_keys and init_state is None:
        missing_in_ckpt = code_keys - ckpt_keys
        missing_locally = ckpt_keys - code_keys
        raise ValueError(
            "Parameters in the code are not matching checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                missing_in_ckpt, missing_locally
            )
        )
    state_sharding = jax.tree_util.tree_map(
        lambda x: jax.sharding.PartitionSpec() if x is None else x,
        state_sharding,
        is_leaf=lambda x: x is None,
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    if params_only:
        state = state.params
    return state