Table of Contents

  • 1  Create model

    • 1.1  Helper training functions

  • 2  Start Libvis

    • 2.1  Watch updates

    • 2.2  Configure poll delay

    • 2.3  Stop app

    • 2.4  Restart app

  • 3  Initialize PyTorch model and data

  • 4  Configure optimizer to use live learning rate

    • 4.1  Set up slider for learning rate

    • 4.2  How does this work?

  • 5  Train the model

  • 6  Button

  • 7  Custom serializer: watch weights live

[1]:
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from libvis import Vis

Using libvis for live monitoring of PyTorch training

Create model

[2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

Helper training functions

We will need to put our loss function values somewhere. You could pass a list as an argument, but here let’s just declare a global variable.
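
For reference, the pass-a-list variant could look like the hypothetical train_with_losses sketched below (not used in this notebook):

[ ]:
# Hypothetical variant: the caller owns the list instead of a global.
def train_with_losses(model, device, train_loader, optimizer, losses):
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(data.to(device)), target.to(device))
        losses.append(loss.item())  # the caller's list accumulates the values
        loss.backward()
        optimizer.step()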

[ ]:
LOSSES = []
[3]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        LOSSES.append(loss.item())
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Start Libvis

[4]:
vis = Vis(ws_port=7700, vis_port=7000)
vis.configure_logging('WARNING')
print('Polling delay:', vis.app._watch_poll_delay)
Started libvis app at http://localhost:7000
Polling delay: 0.2

Watch updates

To monitor a variable in the background, use the vis.watch method. Libvis stores your object internally and sends it to the visualization app every 0.2 seconds.

This method is useful when you have large, frequently updated data and it is acceptable to miss some updates - much like the UDP internet protocol.

[5]:
vis.watch(LOSSES, 'loss')
[5]:
'Legi_0x7f865b2b2d10'
[6]:
ph = 0
while True:
    ph += 1
    # overwrite the watched list in place with a shifting sine wave
    LOSSES[:] = np.sin((np.arange(100) + ph) / 10)
    try:
        time.sleep(.2)
    except KeyboardInterrupt:
        print('Interrupted')
        break
Interrupted
[13]:
LOSSES.clear()

Configure poll delay

This delay determines how often watched variables are sent to the visualization app.

Even if the data has not changed, it will be serialized, sent to the app, and re-rendered. Sending updates too often may result in high CPU usage. If you want a real-time display, you should use direct updates.
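
For contrast, a minimal sketch of a direct update: assigning to vis.vars (covered in detail below) pushes the value immediately instead of on a timer.

[ ]:
# Direct update: sent to the webapp at assignment time,
# independently of the watch poll delay.
vis.vars.loss = list(LOSSES)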

[6]:
vis.app._watch_poll_delay = 1

What happens if we lose the connection? The webapp will start listening for a new one and will reconnect as soon as it is available.

Stop app

You can stop the app by calling vis.stop(). It will stop both the HTTP server for the webapp and the websocket server. If you already have the app open in a browser, it will indicate that it lost the connection.

[7]:
vis.stop()
Stopping webapp http server: `Vis.stop_http()`... OK
Stopping websocket server: `Vis.app.stop()`... OK

Restart app

To restart the same app, use vis.start(). The webapp will detect the new server and restore the connection.

[8]:
vis.start()
Started libvis app at http://localhost:7000

Initialize PyTorch model and data

[9]:
use_cuda = False

device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

epochs=2
batch_size=3000
model = Net().to(device)
[10]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)

Configure optimizer to use live learning rate

[11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

Now that we can read the data in real time, let’s enable live updates to the optimizer’s learning rate.

Numbers in Python are immutable, which means that we can’t just share the value of lr with libvis and have updates from the webapp applied to it.
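
A quick illustration of why sharing the bare float cannot work:

[ ]:
lr = optimizer.param_groups[0]['lr']  # copies the float 0.003 into a new name
lr = 0.05                             # rebinds the local name only
optimizer.param_groups[0]['lr']       # still 0.003: the optimizer is unaffected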

This means we need access to some mutable object that stores the lr for the optimizer. Such an object is optimizer.param_groups.

[12]:
optimizer.param_groups[0].keys()
[12]:
dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])

Set up slider for learning rate

Libvis is meant to be a library of visualization widgets, or modules. There are some pre-installed modules, like the one used by default to represent a list. That built-in LineGraph module has no Python part - its Python side is simply the built-in list type.
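
For instance, assuming a directly assigned list is rendered the same way as a watched one, a plain list is enough to get a line graph:

[ ]:
vis.vars.curve = list(np.sin(np.arange(100) / 10))  # rendered by LineGraph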

Let’s use the Slider submodule from the libvis.modules.uicontrols module.

[13]:
from libvis.modules import uicontrols
[15]:
lr = optimizer.param_groups[0]['lr']
slider = uicontrols.Slider(value=lr, min=0, max=0.05)
vis.vars.lr = slider
[16]:
slider
[16]:
{'on_change': <function print>,
 'min': 0,
 'max': 0.05,
 'value': 0.003,
 'type': 'slider'}

The slider has an on_change attribute with the print function as its value. Let’s change it to something useful.

[17]:
def on_slider(lr_new):
    optimizer.param_groups[0]['lr'] = lr_new
    print('Changed lr to', lr_new)

vis.vars.lr.on_change = on_slider
Changed lr to 0.012
Changed lr to 0.0235

How does this work?

The vis.vars object is a dict, but a special one. First, you can assign keys to it by assigning attributes. Second, each time an attribute is assigned, the value is sent to a websocket connected to vis.vars.

This special object is like a separate ‘channel’ for updates on some state. Libvis uses a separate library, called legimens, to sync states, and vis.vars is a legimens.Object.

Each module in libvis is a legimens.Object. To connect to an object, a client should specify a special token, called a ref. You can get this value by calling legimens.Object.ref on the object.

[19]:
from legimens.Object import ref
ref(vis.vars.lr)
[19]:
'Legi_0x7f8659953450'

Now, anyone who connects through a websocket to localhost:7700/{ref(vis.vars.lr)} will receive updates of the slider value and will be able to update the slider. This is very useful when training on a remote server - you can monitor your process from any internet-enabled device.
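
For illustration, a minimal external listener sketch. It assumes the third-party websockets package; the message format is legimens’ internal detail, so it is printed raw here:

[ ]:
import asyncio
import websockets

async def listen(ref_token):
    # Connect to the websocket endpoint libvis serves on port 7700.
    async with websockets.connect(f'ws://localhost:7700/{ref_token}') as ws:
        async for message in ws:
            print('update:', message)

# asyncio.run(listen(ref(vis.vars.lr)))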

Try this one:

[20]:
slider.value = 0.035

The value in your webapp should update to 0.035. It will not call the on_slider function, since that is called only on updates coming from the websocket.

[22]:
optimizer.param_groups[0]['lr']
[22]:
0.0235

If you want to act as if the update came from the websocket, use the .vis_set(key, value) method. It also does conversion from strings, since all updates from the webapp are serialized to string values.

[30]:
slider.vis_set('value', 0.001)
slider.vis_set('value', '0.001')
optimizer.param_groups[0]['lr']
Changed lr to 0.001
Changed lr to 0.001
[30]:
0.001

Train the model

Now it’s time to play with learning rate while training a model!

[ ]:
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
[26]:
plt.plot(LOSSES)
plt.yscale('log')
plt.grid()
[image: training loss curve on a log scale]

Button

Let’s look at another member of libvis.modules.uicontrols - Button. It’s pretty straightforward: it calls a function when the button is pressed in the app.

[41]:
def increment_slider():
    print('Incrementing!')
    vis.vars.lr.vis_set('value', vis.vars.lr.value + 0.01)

vis.vars.button = uicontrols.Button(label='Press me!', on_press=increment_slider)

Changed lr to 0.0225
Incrementing!
Changed lr to 0.0325
Incrementing!
Changed lr to 0.0425

Custom serializer: watch weights live

[ ]:
from libvis import interface as ifc

def ser(x):
    # clone, move to CPU, detach from autograd, convert to nested lists
    return x.clone().cpu().detach().numpy().tolist()

ser(model.fc2.weight)
[45]:
ifc.add_serializer(type(model.fc2.weight), ser)
[57]:
vis.watch(model.fc2.weight, 'fc2')
[57]:
'Legi_0x7f8648685e50'

Now, if you restart the training, you will be able to see the model weights update in real time! They don’t change much once the model is already trained, so you might want to increase the learning rate.

You can also re-create the model and re-connect the variable.
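
A sketch of that re-creation step:

[ ]:
model = Net().to(device)  # fresh, untrained weights change visibly
optimizer = torch.optim.Adam(model.parameters(), lr=vis.vars.lr.value)
vis.watch(model.fc2.weight, 'fc2')  # re-connect the watch to the new tensor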

And now you have a fully interactive remote dashboard for your learning process!

[image: the resulting dashboard in the libvis webapp]

[64]:
vis.stop()
Stopping webapp http server: `Vis.stop_http()`... OK
Stopping websocket server: `Vis.app.stop()`... OK