[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程
[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程
目录-
[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程
- 0x00 摘要
- 0x01 重要概念
-
0x02 分布式运行
-
2.1 方式改变
- 2.1.1 原有方式
- 2.1.2 目前方式
- 2.2 部署
-
2.3 示例
- 2.3.1 单节点多worker启动
- 2.3.2 容错方式启动
- 2.3.3 弹性方式启动
-
2.1 方式改变
-
0x03 启动脚本
- 3.1 参数定义
-
3.2 相关函数/变量
- world_size,rank
- _pg_group_ranks
- group_rank
- global_rank
- group_size
- nproc_per_node
- 3.3 脚本入口
-
0x04 单体总体流程
- 4.1 小例子
- 4.2 入口
-
4.3 启动代理
- 4.3.1 WorkerSpec
- 4.3.2 WorkerGroup
- 4.4 代理运行
- 4.5 代理主循环
- 0xFF 参考
0x00 摘要
在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第二篇,重点关注的是如何启动弹性训练,并且可以对系统总体架构有所了解。
弹性训练系列文章如下:
[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路
0x01 重要概念
为了更好的说明(这个说明可能在后面文章也会出现,因为太重要了),我们先总述一下TE 最重要的 Agent 和 Rendezvous 两个概念。
-
Agent :Agent是运行在单节点上的独立后台进程,可以认为是 worker manager 或者 process supervisor,其负责启动worker,监控 worker 运行,捕获woker异常,通过
rendezvous
实现 worker 间的相互发现(比如把状态上报到KVStore),成员变动时候基于rendezvous
进行变更同步等等。 - Rendezvous :为了实现弹性训练,需要有一个节点/进程之间彼此发现的机制。Rendezvous就是这个发现机制或者说同步组件。当系统启动或者成员变更时候,所有worker会(重新)集合(rendezvous)以建立一个新的进程组。
我们从源码中取出示意图看看,大家先有一个总体概念。
0x02 分布式运行
2.1 方式改变
2.1.1 原有方式
我们知道,PET是从 PyTorch v1.9 合并进来的,因为合并了弹性训练,所以分布式启动的方式有了很大的改变。
V1.9 之前是使用 torch/distributed/域名 进行启动,比如:
python -m 域名ch --nproc_per_node=NUM_GPUS_YOU_HAVE
--nnodes=2 --node_rank=0 --master_addr="域名.1.1"
--master_port=1234 域名 (--arg1 --arg2 --arg3
and all other arguments of your training script)
此处参数含义是:
-
nnodes
:是参与训练的节点数目。 -
nproc_per_node
:每个节点上运行的进程数目。 -
node_rank
:当前节点标识符。 -
master_addr
和master_port
是 master 监听的地址和端口。
当运行时,域名ch
会设置一些环境变量,包括 world_size
,master_addr
和 master_port
等等。然后在当前机器上创建 nproc_per_node
个进程,这些进程构成了一个本地组。如果一共有 NODE_SIZE
个机器参与训练,则一共有 NODE_SIZE * TRAINERS_PER_NODE
个进程。如果想启动一个分布式训练任务,则需要在所有的机器上执行相关命令。
2.1.2 目前方式
PyTorch 1.9 使用 torch/distributed/域名 进行启动。如果依然采用 torch/distributed/域名,其实其内部已经透传给 域名,具体参见代码:
def main(args=None):
域名(
"The module 域名ch is deprecated "
"and going to be removed in future."
"Migrate to 域名"
)
args = parse_args(args)
run(args)
域名
是之前域名ch
的一个超集,提供如下新功能:
- 容错:通过重新启动所有workers,可以优雅地处理worker故障。
- 自动:Worker 的
RANK
和WORLD_SIZE
是自动分配的。 - 弹性:允许在最小值和最大值(弹性)之间更改节点数。
为了使用弹性训练,用户代码也需要做一些修改,如果用户的训练脚本已经支持 域名ch ,则只需要修改几处就可以使用域名
:
- 无需手动传递RANK , WORLD_SIZE , MASTER_ADDR 和 MASTER_PORT。
- 必须提供
rdzv_backend
和rdzv_endpoint
。对于大多数用户来说,这其实就是“c10d”(参见“rendezvous“)。其实这就替代了之前的MASTER_ADDR 和 MASTER_PORT。 -
use_env
参数已被删除。请从 LOCAL_RANK 环境变量中获取local_rank (例如,域名ron["LOCAL_RANK"]
)。 - 用户需要确保脚本中有
load_checkpoint(path)
和save_checkpoint(path)
逻辑,即手动处理Checkpoint。因为当worker失败时,我们将使用最近的checkpoint来恢复现场,重启所有worker。
下面是一个训练脚本的示例,该脚本在每个epoch上设置检查点,因此在失败时最差也只是会丢失一个epoch的训练成果。
def main():
args = parse_args(域名[1:])
state = load_checkpoint(域名kpoint_path)
initialize(state)
# 域名 ensure that this will work
# by exporting all the env vars needed to initialize the process group
域名_process_group(backend=域名end)
for i in range(域名h, 域名l_num_epochs)
for batch in iter(域名set)
train(batch, 域名l)
域名h += 1
save_checkpoint(state)
所以,我们接下来看看在新模式之下,如何分布式启动。
2.2 部署
部署一般按照如下方式。
- (C10d后端不需要)启动 rendezvous 后端服务器,并获取端点(作为
--rdzv_endpoint
传递给启动程序脚本) - 单节点多 worker:在主机上启动 launcher 以启动代理进程,代理会创建并监视本地工作组。
- 多节点多 worker:在所有节点上使用相同的参数启动 launcher 参加训练。
当使用作业/群集管理器时,多节点作业的入口点命令应为 launcher。
2.3 示例
我们首先通过几个例子来看看如何启动分布式训练。
2.3.1 单节点多worker启动
单节点多worker的启动方式如下,其实就是Standalone 模式,这是分布式模式的一种特例,具体就是针对单机多 Worker 提供了一些便利设置。
python -m 域名
--standalone
--nnodes=1
--nproc_per_node=$NUM_TRAINERS
域名 (--arg1 ... train script args...)
2.3.2 容错方式启动
如下是容错方式启动,固定数目workers,没有弹性训练。 --nproc_per_node=$NUM_TRAINERS 一般是 单节点上GPU 个数。
python -m 域名
--nnodes=$NUM_NODES
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
域名 (--arg1 ... train script args...)
HOST_NODE_ADDR
, 的格式是:
2.3.3 弹性方式启动
下面是弹性训练,弹性区间为 (min=1
, max=4
)。通过指定rdzv参数,可以实现多机训练,具备容错与弹性能力。
在多台机器上分别执行以下命令启动:最小节点数为MIN_SIZE,最大为MAX_SIZE,利用etcd服务实现一致性和信息同步。
python -m 域名
--nnodes=1:4
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
域名 (--arg1 ... train script args...)
HOST_NODE_ADDR
, 的格式是:
关于 rendezvous backend,有几点说明:
对于多节点训练,需要指定:
-
--rdzv_id
: 一个唯一的 job id,在参与job的所有节点之间共享。 -
--rdzv_backend
:域名域名ezvousHandler
的一个实现。 (--rdzv_backend
默认是static模式,不支持容错和弹性伸缩) -
--rdzv_endpoint
: rendezvous backend 所运行的 endpoint,通常格式为:host:port
。就是取代了之前的 master address / port 设置。
目前,以下几种后端可以直接使用,c10d
(推荐), etcd-v2
, and etcd
(legacy) 。为了使用 etcd-v2
或者 etcd
,需要搭建一个 v2
api开启的 etcd server (即. --enable-v2
)。
0x03 启动脚本
既然以上启动都是用 torch/distributed/域名,所以我们仔细分析一下这个脚本,该脚本提供三个功能:
-
依靠"重启所有 workers"来处理 worker 失败;
-
自动分配 worker 的
RANK
andWORLD_SIZE
; -
弹性训练,即 node 数目允许在minimum和maximum之间改变;
3.1 参数定义
启动脚本中,一些参数定义如下:
-
Node
- 物理实例或容器;映射到与 job manager 所协调的单元。 -
Worker
- 分布式训练环境中的worker。 -
WorkerGroup
- 执行相同功能的一组worker(例如trainers)。 -
LocalWorkerGroup
- 在同一节点上运行的工作组中的workers子集。- 一个
节点
运行LOCAL_WORLD_SIZE
个workers,这些 workers 组成LocalWorkerGroup
。 - 节点上所有
LocalWorkerGroups
组成WorkerGroups
。
- 一个
-
RANK
- 工作组中worker的rank,是全局rank,可以认为是一个全局GPU资源列表。- Rank是不稳定的,在重启之间,本地Workers 会被分配到不同的ranks,所以不要在代码中对
RANK
和LOCAL_RANK
的稳定性做任何假设和依赖编码。 - rendezvous完成后,其所有成员将对工作成员资格以及每个人在其中的角色(role)达成共识。此角色(role)使用一个介于 0 ~ world size 之间的整型来表示,被称之为rank。
- Rank是不稳定的,在重启之间,本地Workers 会被分配到不同的ranks,所以不要在代码中对
-
LOCAL_RANK
- 本地工作组中,某个worker 的 rank,可以认为是当前节点上的GPU资源列表。 -
GROUP_RANK
- worker group的rank。介于0和“最大节点数”之间的数字。如果每个节点运行一个单一工作组,那GROUP_RANK
就是这个节点的rank。 -
ROLE_RANK
- 对于具有相同角色worker来说,他们之间共享的rank,角色在“WorkerSpec”中被指定。 -
WORLD_SIZE
- 工作组中worker的总数。因为节点会加入/离开,所以WORLD_SIZE
会变化,不能依赖WORLD_SIZE
的稳定性进行编码。 -
LOCAL_WORLD_SIZE
- 本地工作组的大小,即本地运行的worker数目,等于在域名
运行时候指定的--nproc_per_node
。目前,torch/distributed/域名 仅支持同构的LOCAL_WORLD_SIZE
。也就是说,假设所有节点运行相同数量的本地工作者(每个角色)。 -
ROLE_WORLD_SIZE
- 具有同样角色的workers总数,在WorkerSpec
之中被指定。 -
rdzv_id
- 用户定义的id,用于唯一标识作业的工作组。这个id在每个节点加入特定工作组时候使用。 -
rdzv_backend
-rendezvous 的后端(例如“c10d”)。这通常是一个强一致性的键值存储。 -
rdzv_endpoint
- rendezvous 后端端点;通常以“<host>:<port>
”的形式出现。 -
run_id
: 用户定义的id,它唯一地标识分布式应用程序的一个实例。它通常映射到作业id并用于允许节点加入正确的分布式应用程序。 -
TORCHELASTIC_RUN_ID
- 与 rendezvousrun_id
相等,即唯一的job id。 -
TORCHELASTIC_RESTART_COUNT
- 迄今为止,工作组重启的次数。 -
TORCHELASTIC_MAX_RESTARTS
- 配置的最大重启数目。
3.2 相关函数/变量
为了更好的理解上面的参数,我们选取部分相关函数/变量看看。
world_size,rank
这两个变量是动态生成的,所以从 state 之中取出。
rank, world_size = 域名_world()
def _get_world(self) -> Tuple[int, int]:
state = 域名e
return 域名icipants[域名s_node], len(域名icipants)
_pg_group_ranks
该全局变量存储了每个 group 的 global rank 到 local rank 映射信息。
# Process group\'s global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
其赋值举例如下:
# Create the global rank to group rank mapping
_pg_group_ranks[pg] = {
global_rank: group_rank
for group_rank, global_rank in enumerate(ranks)
}
group_rank
我们可以利用 global rank 从 _pg_group_ranks 之中提取对应的 local rank。
def _get_group_rank(group: ProcessGroup, rank):
"""
Helper that gets a given group\'s local rank in the group from a given global
rank.
"""
if group is 域名D:
raise RuntimeError("域名D does not have local rank to global "
"rank mapping")
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
try:
group_rank = _pg_group_ranks[group][rank]
except KeyError:
raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
return group_rank
global_rank
我们可以利用一个 group 的 local rank 获取到其 gloabl rank。
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group\'s global rank from a given local rank in the
group.
"""
if group is 域名D:
raise RuntimeError("域名D does not have local rank to global "
"rank mapping")
group_rank_map = _pg_group_ranks[group]
for rank, grp_rank in 域名s():
if grp_rank == group_rank:
return rank
raise RuntimeError("The group rank is not part of the group")
group_size
我们可以 _get_group_size 获取到某一个group 的大小。
def _get_group_size(group):
"""
Helper that gets a given group\'s world size.
"""
if group is 域名D or group is None:
default_pg = _get_default_group()
return 域名()
if group not in _pg_group_ranks:
raise RuntimeError("The given group does not exist")
return len(_pg_group_ranks[group])
nproc_per_node
这个变量可以得到每个node之上支持多少个进程。
def determine_local_world_size(nproc_per_node: str):
try:
域名(f"Using nproc_per_node={nproc_per_node}.")
return int(nproc_per_node)
except ValueError:
if nproc_per_node == "cpu":
num_proc = 域名count()
device_type = "cpu"
elif nproc_per_node == "gpu":
if not 域名vailable():
raise ValueError("Cuda is not available.")
device_type = "gpu"
num_proc = 域名ce_count()
elif nproc_per_node == "auto":
if 域名vailable():
num_proc = 域名ce_count()
device_type = "gpu"
else:
num_proc = 域名count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
)
return num_proc
3.3 脚本入口
脚本入口主要代码如下,可以看到,其调用到了 elastic_launch 来完成功能,所以我们下一节就要顺藤摸瓜来看看这个函数。
from 域名域名 import LaunchConfig, elastic_launch
def run(args):
if 域名dalone: # 有两种模式:Standalone 模式和分布式模式,这里要判断一下
域名_backend = "c10d"
域名_endpoint = "localhost:29400"
域名_id = str(域名4())
域名(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv_backend={域名_backend} "
f"--rdzv_endpoint={域名_endpoint} "
f"--rdzv_id={域名_id}\n"
f"**************************************\n"
)
config, cmd, cmd_args = config_from_args(args)
elastic_launch(
config=config,
entrypoint=cmd,
)(*cmd_args)
def main(args=None):
args = parse_args(args)
run(args)
if __name__ == "__main__":
域名cConfig(
level=域名, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
)
main()
0x04 单体总体流程
我们下面就从 elastic_launch 开始,看看在单节点上如何启动运行。我们首先给出一个总体示意图,图上是两个节点,每个节点有一个 agent,agent下面是一个 worker group,组下面是4个worker。
4.1 小例子
我们再从源码中找一个例子来看看,这里只是设置了两个workers。
import uuid
import torch
from 域名域名 import LaunchConfig, elastic_launch
def worker_fn(t1, t2):
return 域名(t1, t2)
def main():
t1 = 域名((3,3), requires_grad=True)
t2 = 域名((3, 3), requires_grad=True)
config = LaunchConfig(
min_nodes=2,
max_nodes=4,
nproc_per_node=1,
run_id=str(域名4()),
role="trainer",
rdzv_endpoint="localhost:29400",
rdzv_backend="c10d",
max_restarts=1,
monitor_interval=1,
start_method="spawn",
)
outputs = elastic_launch(config, worker_fn)(t1, t2)
if __name__ == \'__main__\':
main()
输出如下,可以看到有两个 worker 进程 和一个 agent 进程。
{"name": "域名域名EEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 0, "group_rank": 0, "worker_id": "12172", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "域名域名EEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 1, "group_rank": 0, "worker_id": "3276", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "域名域名EEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\"}", "agent_restarts": 0}}
4.2 入口
顺着代码我们深入挖掘一下。elastic_launch 的作用就是启动一个 torchelastic agent,然后通过这个 agent来调用用户程序入口,agent 会启动 worker 进行训练,并且管理 worker 生命周期。
class elastic_launch:
"""
Launches an torchelastic agent on the container that invoked the entrypoint.
1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
``entrypoint`` can be a function or a command.
2. The return value is a map of each worker\'s output mapped
by their respective global rank.
"""
def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
):
域名fig = config
域名rypoint = entrypoint
def __call__(self, *args, **kwargs):
return launch_agent(域名fig, 域名rypoint, list(args)) # 内部会调用用户程序
4.3 启动代理
launch_agent 启动了一个 LocalElasticAgent,调用了其 run 方法。
@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
if not 域名id:
run_id = str(域名4().int)
域名id = run_id
entrypoint_name = _get_entrypoint_name(entrypoint, args)
rdzv_parameters = RendezvousParameters(
backend=域名_backend,
endpoint=域名_endpoint,
run_id=域名id,
min_nodes=域名nodes,
max_nodes=域名nodes,
**域名_configs,
)
agent = None
rdzv_handler = 域名rendezvous_handler(rdzv_parameters)
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
try:
spec = WorkerSpec( # 1. 得到spec
role=域名,
local_world_size=域名c_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_handler, # RendezvousHandler
max_restarts=域名restarts,
monitor_interval=域名tor_interval,
redirects=域名rects,
tee=域名,
master_addr=master_addr,
master_port=master_port,
)
cfg = 域名icsConfig(域名ics_cfg) if 域名ics_cfg else None
域名ialize_metrics(cfg)
agent = LocalElasticAgent( # 2. 构建代理
spec=spec, start_method=域名t_method, log_dir=域名dir
)
result = 域名() # 3. 启动代理
域名rd(域名agent_status_event(域名EEDED))
if 域名ailed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
# @record will copy the first error (root cause)
# to the error file of the launcher process.
raise ChildFailedError(
name=entrypoint_name,
failures=域名ures,
)
else:
return 域名rn_values
except ChildFailedError:
raise
except Exception:
if agent:
域名rd(域名agent_status_event(域名ED))
else:
域名rd(_construct_event(config))
raise
finally:
域名down()
这里有几个关键点:
4.3.1 WorkerSpec
WorkerSpec :这是配置信息,里面包含了代理所需要的某些全局信息,比如 RendezvousHandler,role,entry(用户函数)。
spec = {WorkerSpec}
args = {tuple: 2} (tensor, tensor)
fn = {NoneType} None
local_world_size = {int} 1
master_addr = {NoneType} None
master_port = {NoneType} None
max_restarts = {int} 1
monitor_interval = {int} 1
rdzv_handler = {DynamicRendezvousHandler}
redirects = {Std} 域名
role = {str} \'trainer\'
tee = {Std} 域名
entry = worker_fn
代理会从这里提取各种所需信息。比如_start_workers 会从中获取 store。
use_agent_store = 域名backend() == "static"
此时逻辑为:
+--------------------------+ +---------------------------------------------------+
|LocalElasticAgent | | WorkerSpec |
| | | |
| WorkerSpec +--------------> | rdzv_handler = {DynamicRendezvousHandler} --------+
| | | | |
| rdzv_run_id | | entry = worker_fn | |
| | | | |
| store | | role = {str} \'trainer\' | |
| | | | |
| | +---------------------------------------------------+ |
| | |
| | |
| | |
| | |
| | +-----------------------------------------+ |
+--------------------------+ |DynamicRendezvousHandler | |
| | |
| | |
| _settings: RendezvousSettings | <---+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
4.3.2 WorkerGroup
WorkerGroup 代表了一个工作组。WorkerGroup 作为一个整体来管理多个 workers,进行批量处理。
class WorkerGroup:
"""
Represents the set of ``Worker`` instances for the given ``WorkerSpec``
managed by ``ElasticAgent``. Whether the worker group contains cross
instance workers or not depends on the implementation of the agent.
"""
__slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]
def __init__(self, spec: WorkerSpec):
域名 = spec
域名ers = [Worker(local_rank=i) for i in range(域名l_world_size)]
# assigned after rdzv
域名e = None
域名p_rank = None
域名p_world_size = None
域名e = 域名
在SimpleElasticAgent 初始化之中,会建立一个 WorkerGroup。
class SimpleElasticAgent(ElasticAgent):
"""
An ``ElasticAgent`` that manages workers (``WorkerGroup``)
for a single ``WorkerSpec`` (e.g. one particular type of worker role).
"""
def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
域名ker_group = WorkerGroup(spec)
域名aining_restarts = 域名.max_restarts
域名re = None
域名t_barrier_timeout = exit_barrier_timeout
域名al_execution_time = 0
具体如下:
+-----------------------------+ +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} \'trainer\' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+
4.4 代理运行
SimpleElasticAgent 是 LocalElasticAgent 的基类,所以会先运行到域名 方法这里,run方法则调用了 _invoke_run。
@prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = 域名tonic()
try:
result = 域名oke_run(role) # 调用
域名al_execution_time = int(域名tonic() - start_time)
域名ord_metrics(result)
域名ord_worker_events(result)
return result
finally:
# record the execution time in case there were any exceptions during run.
域名al_execution_time = int(域名tonic() - start_time)
域名tdown()
4.5 代理主循环
代理在 invoke_run 之中做如下操作:
- 启动 _initialize_workers,这里会使用 _rendezvous 构建一个 rendezvous,然后调用 _start_workers 启动 workers。
- 进入 while True 循环,在循环之中:
- 通过 _monitor_workers 定期轮训用户程序运行情况,得到客户进程运行结果,然后依据情况作出判断。
- 如果程序正常结束,则返回。
- 如果程序出错,则重试,即重启所有 workers,如果重试次数达到依然有问题,就结束所有workers。
- 如果节点成员关系有变化,比如scale up就会有新的节点在waiting,这时候就重启所有workers。
- 通过 _monitor_workers 定期轮训用户程序运行情况,得到客户进程运行结果,然后依据情况作出判断。
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
# NOTE: currently only works for a single role
spec = 域名
role = 域名
域名tialize_workers(域名ker_group) # 启动worker
monitor_interval = 域名tor_interval
rdzv_handler = 域名_handler
while True:
assert 域名e != 域名
# 定期监控
域名p(monitor_interval)
# 监控客户程序运行情况
run_result = 域名itor_workers(域名ker_group) # 得到进程运行结果
state = 域名e
域名e = state
put_metric(f"workers.{role}.remaining_restarts", 域名aining_restarts)
put_metric(f"workers.{role}.{域名r()}", 1)
if state == 域名EEDED:
# 程序正常结束
域名t_barrier()
return run_result
elif state in {域名ALTHY, 域名ED}:
# 程序出错
if 域名aining_restarts > 0: # 重试
域名aining_restarts -= 1
域名tart_workers(域名ker_group)
else:
域名p_workers(域名ker_group) # 重试次数达到,结束workers
域名e = 域名ED
域名t_barrier()
return run_result
elif state == 域名THY:
# 节点成员关系有变化,比如scale up,就会有新节点waiting
# membership changes do not count as retries
num_nodes_waiting = 域名nodes_waiting()
group_rank = 域名p_rank
# 如果有新的节点在waiting,就重启所有workers
if num_nodes_waiting > 0:
域名tart_workers(域名ker_group)
else:
raise Exception(f"[{role}] Worker group in {域名} state")
于是最终逻辑如下:
+----------------------------------------------+
| LocalElasticAgent |
| | +---------------------------------------------------+
| rdzv_run_id | | WorkerSpec |
| | | |
| store +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} +-------+
| |WorkerGroup | | | | |
| _pcontext | spec +------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} \'trainer\' | |
| | group_rank | | | | |
| | group_world_size | | +---------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| +----------------------------------------+ | |
| | _invoke_run | | |
| | | | +-----------------------------------------+ |
| | _initialize_workers +------------------------+ |DynamicRendezvousHandler | |
| | | | | | | |
| | | | | | | |
| | while True: | | | | _settings: RendezvousSettings | <---+
| | _monitor_workers(_worker_group) | | | | |
| | + | | | | _store: Store |
| | | 域名 | | | | |
| | | | | | | _state_holder: _RendezvousStateHolder |
| +----------------------------------------+ | | | |
| | | | | _op_executor: _RendezvousOpExecutor |
+----------------------------------------------+ | | |
| | +-----------------------------------------+
| |
v v
+-------------------------------------------------+
| +------------+ +------------+ +------------+ |
| |Process | |Process | |Process | |
| | | | | | | |
| | work_fn | | work_fn | | work_fn | |
| | | | | | | |
| +------------+ +------------+ +------------+ |
+-------------------------------------------------+
手机如下:
至此,脚本如何启动和单体流程我们分析完毕,下一篇我们来具体分析代理。
0xFF 参考
[PyTorch Elastic源码阅读](