
LLMs / Grok-1: checkpoint.py Explained: a Utility for Loading and Restoring Machine-Learning Model Checkpoints (JAX-based multi-dimensional array computation + large-scale distributed training + cross-host data synchronization and sharding)

Contents

checkpoint.py explained: a utility for loading and restoring machine-learning model checkpoints (JAX-based multi-dimensional array computation + large-scale distributed training + cross-host data synchronization and sharding)

Overview

1. Data-loading classes and functions

2. State-management classes and functions

3. Shared-memory copy functions

4. Utility functions

5. Configuration classes

6. Logging

Full source code


checkpoint.py explained: a utility for loading and restoring machine-learning model checkpoints (JAX-based multi-dimensional array computation + large-scale distributed training + cross-host data synchronization and sharding)

Source: grok-1/checkpoint.py at main · xai-org/grok-1 · GitHub (https://github.com/xai-org/grok-1/blob/main/checkpoint.py)

Overview

This file provides utilities for loading and restoring machine-learning model checkpoints. It uses the JAX library for multi-dimensional array computation, relies on shared memory and multi-threading to speed up file operations, and supports cross-host data synchronization and sharding, which makes it usable for large-scale distributed training.

Key points:

>> Uses shared memory (/dev/shm/) to speed up file reads and writes.

>> Uses a thread pool to load multiple tensors in parallel.

>> Uses the JAX library for multi-dimensional array computation.

>> Supports cross-host data synchronization and sharding.

1. Data-loading classes and functions

load_tensors: loads a set of tensor arrays. Given a list of shaped arrays and a checkpoint directory, it loads the tensors asynchronously through a thread pool (up to 32 workers). Each tensor is split across multiple files, one file per shard, named after the tensor index and the host's shard index.
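
A runnable single-host sketch of load_tensors (assuming Linux with /dev/shm/ available; the temporary directory and toy tensors are made up for illustration). With one process and mesh_config=(1, 1), every tensor i is read from a file named tensor{i:05d}_000:

    import os
    import tempfile
    import jax
    import numpy as np

    ckpt_dir = tempfile.mkdtemp()
    tensors = [np.arange(4, dtype=np.float32), np.eye(2, dtype=np.float32)]
    for i, t in enumerate(tensors):
        # Write each tensor to the file name load_tensors expects on host 0.
        fast_pickle(t, os.path.join(ckpt_dir, f"tensor{i:05d}_000"))

    shaped = [jax.ShapeDtypeStruct(t.shape, t.dtype) for t in tensors]
    loaded = load_tensors(shaped, ckpt_dir, mesh_config=(1, 1))
    assert np.array_equal(loaded[0], tensors[0])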

fast_unpickle and fast_pickle: quickly read (unpickle) and write (pickle) Python objects. They wrap the shared-memory context managers below, copy_to_shm and copy_from_shm, to speed up file I/O.

fast_unpickle: fast deserialization from disk

fast_pickle: fast serialization to disk
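
A minimal round-trip sketch (assumes a Linux host where /dev/shm/ exists; the object and path here are hypothetical):

    import numpy as np

    obj = {"step": 1000, "weights": np.ones((4, 4), dtype=np.float32)}
    fast_pickle(obj, "/tmp/example.pkl")        # pickled via a /dev/shm/ staging file
    restored = fast_unpickle("/tmp/example.pkl")
    assert np.array_equal(obj["weights"], restored["weights"])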

2. State-management classes and functions

restore: restores model state from a checkpoint. This is the main restore entry point: it takes the checkpoint path, the state shapes, the device mesh, the between-hosts configuration, and other parameters; loads the tensors; sanity-checks that the parameter keys match; and returns the globally sharded state (or just its params when params_only is set).

replace_with_load_state: replaces tensors in an initialized state with the corresponding tensors from a loaded state. It uses get_load_path_str to apply path renaming and exclusion rules; excluded paths keep their initialized values.
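
A small, runnable sketch of replace_with_load_state on a single process (the pytrees and rules here are hypothetical, not Grok-1's real parameter names):

    import numpy as np

    init_state = {"params": {"decoder": np.zeros((2, 2)), "head": np.zeros(2)}}
    load_state = {"params": {"transformer": np.ones((2, 2)), "head": np.full(2, 7.0)}}

    merged = replace_with_load_state(
        init_state,
        load_state,
        load_rename_rules=[("decoder", "transformer")],  # init path -> checkpoint path
        load_exclude_rules=["head"],                     # "head" keeps its init value
    )
    # merged["params"]["decoder"] now holds load_state's "transformer" tensor;
    # merged["params"]["head"] keeps the freshly initialized zeros.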

3. Shared-memory copy functions

copy_to_shm and copy_from_shm: two context managers that copy files between shared memory (/dev/shm/) and disk. They are used to speed up Python pickle loading and saving, since reads and writes against shared memory are typically much faster than against disk.

copy_to_shm: stages an on-disk file in shared memory for fast reading

copy_from_shm: yields a shared-memory path for fast writing, then copies the result out to disk
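
Typical usage of the two context managers (the file paths are hypothetical; requires a Linux host with /dev/shm/):

    with copy_to_shm("/data/input.bin") as shm_path:       # staged in /dev/shm/
        with open(shm_path, "rb") as f:
            payload = f.read()                             # fast read from RAM

    with copy_from_shm("/data/output.bin") as shm_path:
        with open(shm_path, "wb") as f:
            f.write(payload)                               # fast write to RAM;
                                                           # copied to disk on exit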

4. Utility functions

path_tuple_to_string: converts a path tuple into a string. A path produced by JAX's tree utilities is a tuple whose elements each describe one level of the data structure (for example a dict key or an object attribute name); this function joins those levels with "/".
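
A quick sketch of what path_tuple_to_string produces for a nested dict:

    import jax

    flat, _ = jax.tree_util.tree_flatten_with_path({"params": {"w": 1}})
    path, _leaf = flat[0]          # path is a tuple of jax.tree_util.DictKey
    print(path_tuple_to_string(path))  # -> "params/w"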

get_load_path_str: computes the load path for an initialized parameter path. It applies the exclusion rules first (returning None for excluded paths), then the first matching rename rule.

multihost_utils.host_local_array_to_global_array: converts host-local arrays into globally sharded arrays over the device mesh; restore uses it to assemble per-host shards into global jax.Arrays.
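
A sketch of get_load_path_str's rule handling (the rename and exclude patterns below are made-up examples):

    rules = [("decoder_layer_([0-9]+)", r"transformer/layer_\1")]
    excludes = ["optimizer_state"]

    print(get_load_path_str("params/decoder_layer_3/w", rules, excludes))
    # -> params/transformer/layer_3/w   (first matching rename rule applied)
    print(get_load_path_str("optimizer_state/mu", rules, excludes))
    # -> None                           (excluded paths are skipped on restore)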

5. Configuration classes

QuantizedWeight8bit: the 8-bit quantized-weight class imported from model.py, typically used to reduce the model's memory footprint. The module also registers it under __main__ (sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit) so that pickled checkpoints referencing it can be unpickled.
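
A conceptual sketch of why 8-bit quantization saves memory; the weight/scales pairing mirrors QuantizedWeight8bit's fields in model.py, but the model's actual dequantization may differ:

    import numpy as np

    # int8 weights plus a small float32 scale array use roughly a quarter of
    # the memory of a float32 weight matrix of the same shape.
    w_int8 = np.random.randint(-128, 128, size=(1024, 1024), dtype=np.int8)
    scales = np.random.rand(1024, 1).astype(np.float32)
    w_fp32 = w_int8.astype(np.float32) * scales   # dequantize when needed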

6. Logging

logger and rank_logger: two logging.Logger instances used to record information and warnings.

logger: the module-level logger, logging.getLogger(__name__)

rank_logger: the "rank" logger, used for per-process messages during restore
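
The module only creates these loggers; configuring handlers is left to the caller. A minimal sketch of enabling the restore messages (the format string is just an example):

    import logging
    import jax

    logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")
    logging.getLogger("rank").info("process %d starting restore", jax.process_index())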

Full source code

# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional

import jax
import numpy as np
from jax.experimental import multihost_utils

from model import QuantizedWeight8bit

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

# Needed for loading the checkpoint with pickle.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit


@contextlib.contextmanager
def copy_to_shm(file: str):
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return

    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
        os.close(fd)


@contextlib.contextmanager
def copy_from_shm(file: str):
    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    try:
        yield tmp_path
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)
        os.close(fd)


def fast_unpickle(path: str) -> Any:
    with copy_to_shm(path) as tmp_path:
        with open(tmp_path, "rb") as f:
            return pickle.load(f)


def fast_pickle(obj: Any, path: str) -> None:
    with copy_from_shm(path) as tmp_path:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)


def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Loads a set of arrays."""
    pool = ThreadPoolExecutor(max_workers=32)
    fs = list()
    num_tensors = 0
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)
    for i, t in iterator:
        if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
            idx = (
                jax.process_index() // (num_replicas * data_model_shards) * data_model_shards
                + jax.process_index() % data_model_shards
            )
            fs.append(
                pool.submit(fast_unpickle, os.path.join(directory, f"tensor{i:05d}_{idx:03d}"))
            )
            num_tensors += 1
        else:
            fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
    wait(fs)
    return [f.result() for f in fs]


def path_tuple_to_string(path: tuple) -> str:
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            pieces.append(elem.key)
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        else:
            assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))
    return "/".join(pieces)


def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    # Exclusion
    if load_exclude_rules is not None:
        for search_pattern in load_exclude_rules:
            if re.search(search_pattern, init_path_str):
                return None

    # Renaming
    load_path_str = init_path_str
    if load_rename_rules is not None:
        for search_pattern, replacement_pattern in load_rename_rules:
            if re.search(search_pattern, load_path_str):
                load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)
                break

    return load_path_str


def replace_with_load_state(
    init_state: Any,
    load_state: Any,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
    mesh_config: tuple = (1, 1),
) -> Any:
    flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
    flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}

    replaced = []
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)
    for i, (init_path, tensor) in enumerate(flatten_init):
        init_path_str = path_tuple_to_string(init_path)
        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)
        if load_path_str is None:
            rank_logger.info(f"Excluded from restore: {init_path_str}.")
            replaced.append(tensor)
        elif load_path_str in load_map:
            if load_path_str == init_path_str:
                rank_logger.info(f"Restored from ckpt: {init_path_str}.")
            else:
                rank_logger.info(f"Restored from ckpt: {init_path_str} <-- {load_path_str}.")
            replaced.append(load_map[load_path_str])
        else:
            rank_logger.info(f"Not found in ckpt: {init_path_str}.")
            if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):
                replaced.append(tensor)
            else:
                replaced.append(np.zeros_like(tensor))

    return jax.tree_util.tree_unflatten(structure_init, replaced)


def restore(
    checkpoint_path: str,
    state_shapes: Any,
    mesh,
    between_hosts_config,
    params_only,
    state_sharding,
    init_state: Optional[Any] = None,
) -> Any:
    ckpt_path = os.path.join(checkpoint_path, "ckpt-0")
    rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
    ckpt_shapes = state_shapes
    ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)
    ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]
    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)

    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)

    # Sanity check to give a better error message.
    ckpt_keys = set(state.params.keys())
    code_keys = set(state_sharding.params.keys())
    if ckpt_keys != code_keys and init_state is None:
        missing_in_ckpt = code_keys - ckpt_keys
        missing_locally = ckpt_keys - code_keys
        raise ValueError(
            "Parameters in the code are not matching checkpoint parameters.\n"
            "Params missing in checkpoint: {}\nParams missing in code: {}".format(
                missing_in_ckpt, missing_locally
            )
        )
    state_sharding = jax.tree_util.tree_map(
        lambda x: jax.sharding.PartitionSpec() if x is None else x,
        state_sharding,
        is_leaf=lambda x: x is None,
    )
    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
    if params_only:
        state = state.params
    return state
