AI | ML/AI 개발 | CUDA

[Pytorch Lightning] Pytorch-Lightning 사용해보기

깜태 2021. 7. 21. 19:14
728x90

최근에 PyTorch-lightning이라는 라이브러리를 알게 되었다.

Pytorch-lightning(PL)은 이름만 들어도, pytorch를 경량화시킨거 같다는 생각이 들었는데,

PL이 pytorch와 다른점이 무엇인지 알아보았다.

 

대략적인 설명은 홈페이지에서 볼 수 있다.

https://www.pytorchlightning.ai/

 

PyTorch Lightning

The ultimate PyTorch research framework. Scale your models, without the boilerplate.

www.pytorchlightning.ai

 

나의 느낀점은 Pytorch는 Python스러운 Numpy를 DL에 접목시킨 확장시킨 라이브러리였다면,

Pytorch Lightning은 "학습"의 한 사이클을 객체화 시켜놓은 느낌이였다.

 

느낀 점 1

 

보통 케라스는 모델을 설계하면, model.fit() 명령어로 알아서 모든게 해결된다.

마찬가지로, PL에서도 모델을 설계하고 Trainer라는 객체를 통해 Trainer.fit(model) 명령어로 모든 게 해결된다.

보통 PyTorch Tutorial을 보면 알겠지만,

 

파이토치에서 보통 학습의 구조는 아래의 구성을 지닌다.

  • 데이터셋 설계
  • 모델 설계
  • 손실함수 설계
  • 학습
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
encoder.cuda(0)
decoder.cuda(0)
# download on rank 0 only
if global_rank == 0:
mnist_train = MNIST(os.getcwd(), train=True, download=True)
# download on rank 0 only
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)

# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_train
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_val

# optimizer
params = [encoder.parameters(), decoder.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# TRAIN LOOP
model.train()
num_epochs = 1
for epoch in range(num_epochs):
  for train_batch in mnist_train:
    x, y = train_batch
    x = x.cuda(0)
    x = x.view(x.size(0), -1)
    z = encoder(x)
    x_hat = decoder(z)
    loss = F.mse_loss(x_hat, x)
    print('train loss: ', loss.item())
    loss.backward()
    optimizer.step()
optimizer.zero_grad()

# EVAL LOOP
model.eval()
with torch.no_grad():
  val_loss = []
  for val_batch in mnist_val:
    x, y = val_batch
    x = x.cuda(0)
    x = x.view(x.size(0), -1)
    z = encoder(x)
    x_hat = decoder(z)
    loss = F.mse_loss(x_hat, x)
    val_loss.append(loss)

  val_loss = torch.mean(torch.tensor(val_loss))
  model.train()

 

그리고, PL에서는 이를 한 클래스로 묶는다.

# model
class LITAutoEncoder(pl.LightningModule):
    def __init__(self):
		super().__init__()
		self.encoder = nn.Sequential(n.Linear(28 * 28, 64), n.ReLU(), nn.Linear(64, 3))
		self.decoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
		encoder.cuda(0)
		decoder.cuda(0)
    def forward(self, x):
		embedding = self.encoder(x)
		return embedding
    def configure_optimizers(self):
		params = [encoder.parameters(), decoder.parameters()]
		optimizer = torch.optim.Adam(self.parameters, lr=1e-3)
		optimizer = torch.optim.Adam(params, lr=1e-3)
		return optimizer
    def training_step(self, train_batch, batch_idx):
		x, y = train_batch
		x = x.cuda(0)
		x = x.view(x.size(0), -1)
		z =self.encoder(x)
        x_hat =self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log(‘train_loss’, loss)
        return loss
    def validation_step(self, val_batch, batch_idx):
    	x, y = val_batch
		x = x.cuda(0)
		x = x.view(x.size(0), -1)
		z =self.encoder(x)
        x_hat =self.decoder(z)
		loss = F.mse_loss(x_hat, x)
		val_loss.append(loss)
		self.log(‘val_loss’, loss)
    def backward(self, trainer, loss, optimizer, optimizer_idx):
		loss.backward()
‍
# train
model = LITAutoEncoder()
trainer = pl.Trainer()
trainer.fit(model, mnist_train, mnist_val)

 

한 모델에 대해 데이터셋, 옵티마이저, 손실함수까지 모든 걸 설정해서 객체화시키면 알아서 학습시켜준다.

좀 더 객체지향적이고, 깔끔하다고 볼 수 있을 것 같다.

그렇다면 의문이 드는게 모델에서 데이터셋까지 다 정해버리면,

데이터셋을 변경하는 경우에도 일일이 코딩을 해야되나? 란 생각도 들었다.

 

기존의 파이토치에서도 DataLoader를 썼던 것처럼 당연하게도 PL에서도 역시
pl.LightningDataModule 이라는 모듈을 통해 데이터셋을 변경하는 게 가능하다.

만약 내가 dm이라는 데이터셋 모듈을 정의하면, 그 모듈을 모델에 입력하고,
학습할 때  fit 명령어의 파라미터로 데이터셋 인스턴스를 넣어주면 된다.

dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    progress_bar_refresh_rate=20,
    gpus=AVAIL_GPUS,
)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)

 

이런 면에서 PL은 깔끔한 편이라는 생각이 들었다.

 

728x90