少女祈祷中...

教程,来自官网、blog等。

Hydra

get start

读取配置文件的一个包,可以读取制指定文件夹下的制定配置文件,安装方法:

1
pip install hydra-core

这会安装hydra, omegacong等。

创建folder conf,在里面创建我们的配置文件config.yaml:

1
2
3
4
5
known_host:
host: 120.76.43.27
port: 62222
user: chengyiqiu
pwd: secert

运行下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path='./conf', config_name='config')
def test(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))
if cfg.known_host.port == 62222:
print('ok')


test()

若是将cgf返回,得到的是None,但是可以在test()内部对cfg的内部进行判定。

也可以访问上级目录:

1
2
3
@hydra.main(version_base=None, config_path='..', config_name='test')
def test_(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))

也可以将config转变成Object,但是不能超过这个函数的生命周期,否则会变成None,在生命周期内部,可以将config进行传参,只要没超出生命周期即可:

1
2
3
4
5
6
7
8
9
10
11
12
@hydra.main(version_base=None, config_path='..', config_name='test')
def test_(cfg: DictConfig):
# ---------lifetime start ---------
cfg = OmegaConf.to_object(cfg) # object
train(cfg)
print()
# ---------lifetime over ---------


if __name__ == '__main__':
cfg = test_() # None
print(cfg)

重载

可以通过命令行传入参数来重载配置:

1
2
3
4
5
6
7
8
9
(tutorials) chengyiqiu@chengyiqiu:~/code/tutorials/Hydra$ python get_start.py known_host.port=8888
known_host:
host: 120.76.43.27
port: 8888
user: chengyiqiu
pwd: secert

not ok

1
2
3
4
5
6
7
8
9
10
11
@hydra.main(version_base=None, config_path='./conf', config_name='config')
def test(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))
if cfg.known_host.port == 62222:
print('ok')
else:
print('not ok')


if __name__ == '__main__':
test()

此举动不会更改原始的yaml中的参数配置,但是当进程运行结束后,会创建日志,重载的配置:

image-20240403165108222

image-20240403165118016

封装

若是想要在多个配置文件中进行选择,可以用对配置文件做进一步封装,如下目录树:

1
2
3
4
5
6
7
8
9
10
11
(tutorials) chengyiqiu@chengyiqiu:~/code/tutorials/Hydra$ tree ./
./
├── conf
│   └── config.yaml
├── get_start.py
└── host
├── config.yaml
└── user
├── user1.yaml
└── user2.yaml

config.yaml中的内容:

1
2
defaults:
- user: user1

user1.yaml中的内容:

1
2
3
4
5
user:
username: chengyiqiu
password: 1234
ipv4: 120.76.43.27
port: 62222

Code:

1
2
3
4
5
6
7
8
@hydra.main(version_base=None, config_path='./host', config_name='config')
def test_default_config(cfg: DictConfig):
print(OmegaConf.to_yaml(cfg))
return


if __name__ == '__main__':
test_default_config()

能够定向到user1.yaml,同样的方式,也可以使用重载来重新选择对应的配置文件;

1
2
3
4
5
6
7
(tutorials) chengyiqiu@chengyiqiu:~/code/tutorials/Hydra$ python get_start.py user=user2
user:
user:
username: qcy
password: 12
port: 1
ipv4: 1.1.1.1

也可以重载新的配置文件中的参数:

1
2
3
4
5
6
7
8
(tutorials) chengyiqiu@chengyiqiu:~/code/tutorials/Hydra$ python get_start.py user=user2 user.user.port=1111
user:
user:
username: qcy
password: 12
port: 1111
ipv4: 1.1.1.1

pytorch_lighting

train

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MyLightningModule(L.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
self.validation_step_outputs = []

def forward(self, x):
return self.model(x)

def training_step(self, batch):
x, y = batch
y_p = self.forward(x)
loss = torch.nn.functional.cross_entropy(y_p, y)
return loss

def validation_step(self, batch):
x, y = batch
y_p = self.forward(x)
loss = torch.nn.functional.cross_entropy(y_p, y)
pred_labels = torch.argmax(y_p, dim=1)
correct = (pred_labels == y).sum().item()
accuracy = correct / x.shape[0]
self.validation_step_outputs.append(accuracy)
self.log('val_loss', loss, prog_bar=True)
self.log('val_accuracy', accuracy, prog_bar=True)
return accuracy

def test_step(self, batch):
x, y = batch
y_p = self.forward(x)
loss = torch.nn.functional.cross_entropy(y_p, y)
pred_labels = torch.argmax(y_p, dim=1)
correct = (pred_labels == y).sum().item()
accuracy = correct / x.shape[0]
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, self.current_epoch)
self.log('test_loss', loss, prog_bar=True)
self.log('test_accuracy', accuracy, prog_bar=True)
return {"test_loss": loss, "test_accuracy": accuracy}

继承L.LightningModule,重写4个方法即可训练。

singularity

1
2
3
4
5
singularity pull --arch arm64 arm64-ubuntu.sif library://ubuntu:22.04

singularity build --sandbox container-sandbox/ arm64-ubuntu.sif

singularity shell --writable --fakeroot container-sandbox/