{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "# RNN: Прогнозирование следующей буквы\n", "
\n", "\n", "\n", "\n", "Описание: NN_RNN_Torch.html," ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Библиотеки" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "#import ctypes\n", "#ctypes.cdll.LoadLibrary('caffe2_nvrtc.dll')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Загружаем текст" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Akunin_Priklyucheniya-Erasta-Fandorina_1_Azazel.txt: 53879 words\n", "Akunin_Provincialnyy-detektiv_1_Pelagiya-i-belyy-buldog.txt: 68067 words\n", "Akunin_Zhanry_2_Shpionskiy-roman.txt: 63546 words\n", "Strugackiy_Maksim-Kammerer_1_Obitaemyy-ostrov.txt: 96243 words\n", "Strugackiy_Maksim-Kammerer_2_Zhuk-v-muraveynike.txt: 53396 words\n", "Strugackiy_Maksim-Kammerer_3_Volny-gasyat-veter.txt: 41175 words\n", "Strugackiy_Ponedelnik-nachinaetsya-v-subbotu.txt: 60048 words\n", "Strugackiy_Ulitka-na-sklone.txt: 60777 words\n", "Tolstoy_Voyna-i-mir-Tom-1.txt: 104440 words\n", "Tolstoy_Voyna-i-mir-Tom-2.txt: 116938 words\n", "Tolstoy_Voyna-i-mir-Tom-3.txt: 127024 words\n", "Tolstoy_Voyna-i-mir-Tom-4.txt: 104756 words\n", "Vasilev_Akula-pera-v-Mire-Fayrolla_1_Igra-ne-radi-igry.txt: 81571 words\n", "Vasilev_A-Smolin-vedmak_1_Chuzhaya-sila.txt: 109389 words\n", "Vasilev_Kovcheg-5-0_1_Mesto-pod-solncem.txt: 131243 words\n", "Vasilev_Ucheniki-Vorona_1_Zamok-na-Voroney-gore.txt: 97007 words\n", "chars: 35\n", "|борис акунин\n", "азазель\n", "глава первая \n", "в которой описывается некая циничная выходка\n", "в понедельник мая го|\n", "|\n", "фланконада удар наносимый в бок противника под самый локоть.\n", "огонь антония некроз тканей гангрена.\n", "|\n", "8467974\n" ] } ], "source": [ "import re, zipfile, numpy as np, time\n", "from time import perf_counter as tm # таймер sec\n", "\n", "CHARS = \" 
.абвгдежзийклмнопрстуфхцчшщъыьэюя\\n\" # алфавит\n", "charID = { c:i for i,c in enumerate(CHARS) } # буква в номер\n", "\n", "def preprocess(txt):\n", " \"\"\" Буквы не из алфавита заменяем пробелами \"\"\"\n", " txt = txt.lower().replace('ё','e')\n", " txt = txt.lower().replace('?','.')\n", " txt = ''.join( [c if c in CHARS else ' ' for c in txt] ) \n", " txt = re.sub(' +', ' ', txt).replace(' .', '.')\n", " return re.sub('\\n\\s+', '\\n', txt)\n", "\n", "def load_Sultan():\n", " with open(\"C:/!/Python/Data/NLP/saltan.txt\", \"r\", encoding='utf-8-sig') as file:\n", " return preprocess ( file.read() )\n", " \n", "def load_Books():\n", " txt = \"\"\n", " with zipfile.ZipFile(\"C:/!/Data/nlp/books/books.zip\") as myzip:\n", " for fname in myzip.namelist():\n", " print(fname, end=\": \")\n", " with myzip.open(fname) as myfile:\n", " st = preprocess ( myfile.read().decode(\"utf-8\") ) \n", " print(len(st.split()), \"words\") \n", " txt += st\n", " return txt\n", " \n", "#text = load_Sultan() \n", "text = load_Books() \n", " \n", "print(f\"chars: {len(CHARS)}\") \n", "print(f\"|{text[:100]}|\") \n", "print(f\"|{text[-100:]}|\") \n", "print(len(text))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Параметры" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "LENGTH, STEP, NUM, = 25, 25, 20\n", "E_DIM, H_DIM, NUM_LAYERS, DROP = 10, 250, 1, 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Готовим данные" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trn: torch.Size([254037, 25]) torch.Size([254037, 25]) tensor(34) tensor(34)\n", "val: torch.Size([84679, 25]) torch.Size([84679, 25]) tensor(34) tensor(34)\n", "len(textID):8467974 num:338716, STEP:25\n", "torch.Size([254037, 25]) torch.Size([254037, 25]) tensor(34) tensor(34)\n", "[3, 16, 18, 10, 19, 0, 2, 12, 21, 15, 10, 15, 34, 2, 9, 2, 9, 7, 13, 30, 34, 5, 
13, 2, 4, 2, 0, 17, 7, 18, 4, 2, 33, 0, 34, 4, 0, 12, 16, 20, 16, 18, 16, 11, 0, 16, 17, 10, 19, 29, 4, 2, 7, 20, 19, 33, 0, 15, 7, 12, 2, 33, 0, 24, 10, 15, 10, 25, 15, 2, 33, 0, 4, 29, 23, 16, 6, 12, 2, 34, 4, 0, 17, 16, 15, 7, 6, 7, 13, 30, 15, 10, 12, 0, 14, 2, 33, 0, 5, 16]\n", "то человек бывший для мен\n", "о человек бывший для меня\n", "Wall time: 11.9 s\n" ] } ], "source": [ "%%time\n", "\n", "textID = [ charID[c] for c in text ]\n", "\n", "num_seq = int((len(textID)-LENGTH)/STEP)-1 # число последовательностей\n", "\n", "X_dat = torch.empty (num_seq, LENGTH, dtype=torch.long)\n", "Y_dat = torch.empty (num_seq, LENGTH, dtype=torch.long)\n", "\n", "for i in range(num_seq): \n", " X_dat[i] = torch.tensor(textID[i*STEP: i*STEP+LENGTH], dtype=torch.long)\n", " Y_dat[i] = torch.tensor(textID[i*STEP+1: i*STEP+LENGTH+1], dtype=torch.long)\n", " \n", "idx = torch.randperm( len(X_dat) ) # перемешанный список индексов\n", "X_dat = X_dat[idx]\n", "Y_dat = Y_dat[idx]\n", "\n", "num_trn = int(0.75*len(X_dat))\n", "X_trn, Y_trn = X_dat[:num_trn], Y_dat[:num_trn]\n", "X_val, Y_val = X_dat[num_trn:], Y_dat[num_trn:]\n", "\n", "print(\"trn:\", X_trn.shape, Y_trn.shape, X_trn.max(), Y_trn.max()) \n", "print(\"val:\", X_val.shape, Y_val.shape, X_val.max(), Y_val.max()) \n", " \n", "def tensor2st(t):\n", " return ''.join( [ CHARS[i] for i in t])\n", "\n", "print(f\"len(textID):{len(textID)} num:{num_seq}, STEP:{STEP}\")\n", "print(X_trn.shape, Y_trn.shape, X_trn.max(), Y_trn.max()) \n", "print(textID[:100])\n", "print(tensor2st(X_trn[0]))\n", "print(tensor2st(Y_trn[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Модель\n", "\n", "Для экономи памяти и ускорения, модель внутри функции forward вычисляет ошибку" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C:35, E:10, H:250, LAYERS:1\n", "torch.Size([5, 35, 2]) torch.Size([1, 5, 250])\n" ] } ], "source": [ "class 
RNN(nn.Module):\n", " def __init__(self, E, H, LAYERS=1): \n", " super(RNN, self).__init__() \n", " self.H = H\n", " self.mode = 1\n", " self.rnn = nn.GRU(E, H, num_layers=LAYERS)\n", " \n", " def forward(self, x, h0=None): # (L,B,E), (1,B,H) \n", " if self.mode == 1: return self.forward1( x, h0 )\n", " if self.mode == 2: return self.forward2( x, h0 ) \n", " if self.mode == 3: return self.forward3( x, h0 ) \n", " \n", " def forward1(self, x, h0=None): # (L,B,E), (LAYERS,B,H) \n", " return self.rnn(x, h0) # (L,B,H), (LAYERS,B,H) \n", " \n", " def forward2(self, x, h0=None):\n", " (L, B, E), H = x.size(), self.H \n", "\n", " self.h = [ torch.zeros(1,B,H, device = x.device) if h0 is None else h0 ] \n", " for i in range(L): # по всем ячейкам\n", " _, h = self.rnn( x[i].view(1,B,E), self.h[i] ) \n", " self.h.append( h.clone() ) # запоминем скрытые состояния\n", " \n", " for h in self.h: h.retain_grad() # будут помнить свои градиенты \n", "\n", " self.y = torch.cat(self.h[1:], dim=0) # (L,B,H) объединяем все выходы\n", " self.y.retain_grad() # будут помнить свои градиенты \n", " \n", " return self.y, self.h[-1] # (L,B,H), (1,B,H) \n", "\n", " def forward3(self, x, h0=None):\n", " (L, B, E), H = x.size(), self.H \n", "\n", " self.h = [ torch.zeros(1,B,H, device = x.device) if h0 is None else h0 ] \n", " for i in range(L): # по всем ячейкам\n", " y, h = self.rnn( x[i].view(1,B,E), self.h[i] ) \n", " self.h.append( h.clone() ) # запоминем скрытые состояния\n", " self.y = y if i==0 else torch.cat([self.y, y], dim=0)\n", " \n", " for h in self.h: h.retain_grad() # будут помнить свои градиенты \n", "\n", " self.y.retain_grad() # будут помнить свои градиенты \n", " return self.y, self.h[-1] # (L,B,H), (1,B,H) \n", " \n", " \n", " def stat(self): \n", " if self.mode == 1:\n", " return\n", " self.V_h, self.G_h = torch.zeros(len(self.h)), torch.zeros(len(self.h))\n", " for i,h in enumerate(self.h): \n", " self.G_h[i] = 0 if h.grad is None else ((h.grad**2).mean())**0.5 \n", " 
self.V_h[i] = ((h.detach()**2).mean())**0.5 \n", " self.G_y = ((self.y.grad**2).mean(dim=(1,2)))**0.5 \n", "\n", "class Model(nn.Module):\n", " def __init__(self, C, E, H, LAYERS = 1, DROP=0, EMBED=True): \n", " \"\"\"число классов, размерности эмбединга и скрытого состояния, слоёв\"\"\" \n", " super(Model, self).__init__() \n", " print(f\"C:{C}, E:{E}, H:{H}, LAYERS:{LAYERS}\")\n", " self.C = C \n", " self.EMBED = EMBED\n", " \n", " if EMBED:\n", " self.emb = nn.Embedding(C, E, scale_grad_by_freq=True) \n", " self.rnn = RNN(E,H,LAYERS) \n", " self.fc = nn.Linear(H, C) \n", " self.drop = nn.Dropout(DROP)\n", " \n", " def forward(self, x, h0=None, NUM=1): # (B,L), (1,B,H) \n", " if self.EMBED:\n", " x = self.emb ( x.t() ) # (L,B,E) \n", " else:\n", " x = torch.zeros(len(x), x.size(1), self.C).scatter_(2, x.unsqueeze(2), 1.) # (B,L,C)\n", " x = x.transpose(0,1).contiguous()\n", "\n", " yr, hr = self.rnn ( x, h0 ) # (L,B,H), (1,B,H) \n", " \n", " y = yr[-NUM : ] # (NUM,B,H) последние выходы \n", " y = self.drop(y)\n", " y = self.fc(y) # (NUM,B,C) \n", " return y.permute(1,2,0), hr # (B,C,NUM), (LAYERS,B,H) \n", " \n", "model = Model(len(CHARS), E_DIM, H_DIM, NUM_LAYERS, EMBED=True) # экземпляр сети\n", "\n", "gpu = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "cpu = torch.device(\"cpu\")\n", "\n", "model.to(gpu)\n", "\n", "y, h_rnn = model( x = torch.zeros(5, LENGTH, dtype=torch.long).to(gpu), NUM=2 )\n", "print(y.shape, h_rnn.shape)\n", "\n", "losses = [] # история ошибок для графика" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "emb.weight : 350 shape: (35, 10) \n", "rnn.rnn.weight_ih_l0 : 7500 shape: (750, 10) \n", "rnn.rnn.weight_hh_l0 : 187500 shape: (750, 250) \n", "rnn.rnn.bias_ih_l0 : 750 shape: (750,) \n", "rnn.rnn.bias_hh_l0 : 750 shape: (750,) \n", "fc.weight : 8750 shape: (35, 250) \n", "fc.bias : 35 shape: (35,) \n", "total : 205635\n" ] } ], 
"source": [ "tot = 0\n", "for k, v in model.state_dict().items():\n", " pars = torch.tensor(list(v.shape)).prod(); tot += pars\n", " print(f'{k:20s} :{pars:7d} shape: {tuple(v.shape)} ')\n", "print(f\"{'total':20s} :{tot:7d}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Веса" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[' :1.85', '.:4.26', 'а:2.70', 'б:4.27', 'в:3.32', 'г:4.20', 'д:3.67', 'е:2.67', 'ж:4.80', 'з:4.26', 'и:2.91', 'й:4.69', 'к:3.56', 'л:3.22', 'м:3.65', 'н:2.91', 'о:2.37', 'п:3.77', 'р:3.31', 'с:3.14', 'т:2.98', 'у:3.75', 'ф:6.42', 'х:4.98', 'ц:5.82', 'ч:4.39', 'ш:4.96', 'щ:5.86', 'ъ:8.22', 'ы:4.18', 'ь:4.16', 'э:5.82', 'ю:5.34', 'я:4.02', '\\n:5.27']\n", "tensor(1.0000)\n", "Wall time: 2min 11s\n" ] } ], "source": [ "%%time\n", "weight = torch.ones(len(CHARS), dtype=torch.float)\n", "for c in textID:\n", " weight[c] += 1\n", "weight /= len(text)\n", "\n", "weight = -weight.log_()\n", "print ( [f'{c}:{weight[i]:.2f}' for i,c in enumerate(CHARS)] )\n", "weight /= weight.sum()\n", "print(weight.sum())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Обучение" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters())\n", "CE_loss = nn.CrossEntropyLoss(weight.to(gpu))\n", "#CE_loss = nn.CrossEntropyLoss()\n", "\n", "def calc_acc(y, yb, NUM): # (B,C,NUM), (B,L) \n", " \"\"\" Вычисляем точность \"\"\" \n", " _,idx = y[:,:,-1].detach().topk(1, dim=1) # (B,) макс.индекс \n", " return (idx.view(-1) == yb[:,-1]).float().mean() # точность определения класса\n", "\n", "def fit(model, X,Y, batch_size=64, NUM=1, train=True): \n", " model.train(train) # важно для Dropout, BatchNorm\n", " sumL, sumA, iters = 0, 0, int(len(X)/batch_size) \n", " \n", " num_oks = torch.zeros(len(CHARS), dtype=torch.float)\n", " num_tot = torch.zeros_like(num_oks) \n", " \n", " 
start1, start2 = tm(), tm()\n", " for it in range(iters): # примеры разбиты на пачки \n", " \n", " xb = X[it*batch_size: (it+1)*batch_size].to(gpu) # (B,L)\n", " yb = Y[it*batch_size: (it+1)*batch_size].to(gpu) # (B,L) \n", " \n", " y, _ = model(xb, NUM = NUM) \n", " L = CE_loss(y, yb[:, -NUM:]) \n", " \n", " sumL += L.detach().item()\n", " sumA += calc_acc(y, yb, NUM)\n", " \n", " if train: # в режиме обучения \n", " optimizer.zero_grad() # обнуляем градиенты \n", " L.backward() # вычисляем градиенты \n", " optimizer.step() # подправляем параметры\n", " if it+1 == iters:\n", " model.rnn.stat() # вычисляем статистики \n", " \n", " if tm() - start2 > 1 or it+1==iters:\n", " print('\\r', f\"{'trn' if train else 'val'} { iters*(tm()-start1)/(it+1) :.1f}s {iters}:{100*(it+1)/iters:.0f}% \", end='')\n", " print(f\"loss: {sumL/(it+1):.4f} acc: {sumA/(it+1):.4f}\", end='')\n", " start2 = tm()\n", " \n", " return sumL/iters, sumA/iters\n", " \n", "def params_stat(model):\n", " for k, v in model.state_dict().items():\n", " pars = torch.tensor(list(v.shape)).prod()\n", " aV = ((v.detach()**2).mean())**0.5\n", " aG = 0 #((v.grad**2).mean())**0.5\n", " print(f'v:{aV:.4f} g:{aG:.4f} {k:20s} :{pars:7d} shape: {tuple(v.shape)} ') \n", "\n", "def calc_acc2(idx, yb): \n", " oks = yb[idx==yb]\n", " for c in range(len(CHARS)):\n", " num_tot[c] += len( yb [yb==c] )\n", " num_oks[c] += len( oks[oks==c])\n", " \n", " num_tot[num_tot==0] = 1\n", " return num_oks/num_tot, num_tot/(iters*batch_size*yb.size(1)) \n", " \n", "#L_trn, _, _, _, _ = fit(model, X_trn, Y_trn, 100, train=False) \n", "#losses.append([L_trn,L_trn])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA6IAAAD8CAYAAABtlBmdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3dfWwc933n8c939oHLJ5EyJVGyaEtKJNZPSe2GUJ1zDjAOV8B2kvofH05B0ODyR42mF6AOckCM/pGkQA7IXwHaOIlhoEbOaNAgaNrC7TnN5e5iOwmanCU/ylYsyold07JJiZZILp9393d/zCw5O5rdHUrLIbn7fgELzv7mt7O/lUaz+vD3MOacEwAAAAAAafG2ugEAAAAAgM5CEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqpoGUTO7wcx+amZnzOxVM/uzmDp3m9mMmb0YPL68Oc0FAAAAAOx02QR1SpK+6Jx73sz6JZ0ys584516L1PuZc+4TrW8iAAAAAKCdNO0Rdc6965x7Ptiek3RG0sHNbhgAAAAAoD0l6RFdY2aHJd0h6Vcxuz9qZi9JOi/pvznnXo15/YOSHpSk3t7ej9x0000bbS8AAAAAYAc4derURefc3rh95pxLdBAz65P0jKT/7pz7+8i+XZIqzrmimd0n6S+dc8caHW9sbMydPHky0XsDAAAAAHYWMzvlnBuL25do1Vwzy0n6oaTvRUOoJDnnZp1zxWD7KUk5M9tzDW0GAAAAALSpJKvmmqS/lnTGOfeNOnX2B/VkZseD4063sqEAAAAAgPaQZI7oXZL+SNIrZvZiUPbnkm6UJOfco5IekPQ5MytJWpR0wiUd8wsAAAAA6ChNg6hz7ueSrEmdRyQ90qpGAQAAAEC7W11d1cTEhJaWlra6KdekUChoZGREuVwu8Ws2tGouAAAAAKA1JiYm1N/fr8OHDyuY6bjjOOc0PT2tiYkJHTlyJPHrEi1WBAAAAABoraWlJQ0NDe3YECpJZqahoaEN9+oSRAEAAABgi+zkEFp1NZ+BIAoAAAAASBVBFAAAAAA60OXLl/Xtb397S96bIAoAAAAAHaheEC2Xy5v+3qyaW8cfPvJzDfXmNTrcr2PD/Rod7tPRfX3qyfNHBgAAAGDne/jhh/XGG2/o9ttvVy6XU19fnw4cOKAXX3xR3/72t/XVr35Ve/bs0enTp/WRj3xEf/M3f9OyOa2kqhgrpYoODfVqfHJOvzg3rZVyZW3fyO5uP5zu6yOgAgAAAGiJv/inV/Xa+dmWHvOW63fpK5+8te7+r3/96zp9+rRefPFFPf300/r4xz+u06dP68iRI3r66af1wgsv6NVXX9X111+vu+66S7/4xS/0sY99rCVtIz3FyGc9ffNTd0iSSuWK3np/QeOTczo7WdT4VFHjk3P6+fjF+IA63KfRff5PAioAAACAneL48eM19wI9fvy4RkZGJEm333673nzzTYJoWrIZTx/c26cP7u3TPbetl0cD6tnJOZ2bKtYEVLMgoO7r19EgoI4O9+uD+3oJqAAAAADWNOq5TEtvb2/N866urrXtTCajUqnUsvciDV2lRgH1zWk/oI5P+QF1fLKoZ8cvaLXsJNUG1GPBMN/R4X4d3den7nxmiz4RAAAAgE7S39+vubm5LXlvgmiLZTOeju7zh+XeGypfLVf01nR4iG98QL1hd0/N/NPR4X59cC8BFQAAAEBrDQ0N6a677tJtt92m7u5uDQ8Pp/be5pxL7c3CxsbG3MmTJ7fkvbcTP6DOa3yy6A/xnZrT+OScfntx/oqA6i+MREAFAAAA2sGZM2d08803b3UzWiLus5jZKefcWFx9ekS3WC7j6ei+fh3d1697P7ReXg2o1fmn1UWSnjl7ZQ/q6PB6D+qxff4Q30KOgAoAAABgeyKIblPhgHrfhw6slV8RUINhvk+/fkGlynpAvfG6Hh0LVu8
loAIAAADYTgiiO0yjgPrmxfma+adnJ+f09OtTsQHV70UloAIAAABIH0G0TeQynr8C73C/pPWAulKq7UE9N3VlQPWCgBqef3ps2F8RmIAKAAAAoNUIom0un10PqB+PBNQ31xZJmtP4lL+ab1xADd9ihoAKAAAA4FoRRDtUPutpdLhfo3UC6tngNjPngoD601/HB9RqD+rRfQRUAAAAAMkQRFEjHFDDogF1PFjJ9//+ekrlUEA9NNQb3Ac16EHd168P7O0loAIAAAA7XF9fn4rFYkuORRBFIo0C6m8vztfcYubs5Jz+T52AWh3eS0AFAAAAOhdBFNckn/X0O/v79Tv7awPqcqmsNy8uBLeY8UNqXEA9PNSro6GAOjrcryN7CKgAAADAZvvSl76kQ4cO6U//9E8lSV/96ldlZnr22Wd16dIlra6u6mtf+5ruv//+lr+3OedaftAkxsbG3MmTJ7fkvbF1lktl/faiv0jSeDDM9+zUnN6aXrgioB4LzT8dHfZ7ULuyBFQAAAC0hzNnzujmm2/2n/zoYem9V1r7Bvs/JN379bq7X3jhBT300EN65plnJEm33HKL/uVf/kWDg4PatWuXLl68qDvvvFPj4+Mys4ZDc2s+S8DMTjnnxuLq0yOKVHVlM7pp/y7dtH9XTXk1oK7NPw0C6v8+s96DmvFMh4Z6QkN8/dV8CagAAADAxt1xxx2amprS+fPndeHCBe3evVsHDhzQF77wBT377LPyPE/vvPOOJicntX///pa+N0EU20LSgFqdixoXUEf3BfNPg9V8j+whoAIAAGCHaNBzuZkeeOAB/d3f/Z3ee+89nThxQt/73vd04cIFnTp1SrlcTocPH9bS0lLL35cgim2tUUD9zQV/kaRzwfzTs5Nz+l+vvacgn9YE1NHhPh0loAIAAAA1Tpw4oT/+4z/WxYsX9cwzz+gHP/iB9u3bp1wup5/+9Kd66623NuV9CaLYkbqyGd18YJduPlAbUJdWy+ur+E4WNT4VH1APD/XoWBBQjwWrAR/Z06t81tuCTwMAAABsjVtvvVVzc3M6ePCgDhw4oE9/+tP65Cc/qbGxMd1+++266aabNuV9CaJoK4Vc/YD6mwvzGp8K5p9Ozun1OgE1PP+UgAoAAIB298or64sk7dmzR//6r/8aW69V9xCVCKLoEIVcRrdcv0u3XF8/oFZ7UX/93px+/Op6QM16psN7/PugVuefHttHQAUAAACuFkEUHS1pQD3bIKCODvfpaDDMd3S4X4eHCKgAAABAIwRRIEajgPrGhWJo/mlRr52f1Y9OvycXE1D9eaj+ar4EVAAAAEQ552RmW92Ma+Kq/xHeAIIosAGFXEa3Xj+gW68fqCkPB9TqLWbiAuqRPb3+LWaCgDo63KfDe3qVyxBQAQAAOk2hUND09LSGhoZ2bBh1zml6elqFQmFDryOIAi3QKKCemyqGbjFT1Kt1Amq157S6mi8BFQAAoL2NjIxoYmJCFy5c2OqmXJNCoaCRkZENvYYgCmyiQi6j2w4O6LaD8QG1Orx3fLKo0+dn9NTpd9cCai4T9KDu8wPqaLCSLwEVAACgPeRyOR05cmSrm7ElCKLAFqgXUBdXgiG+awF1rn5ADd1iZnS4T4eGCKgAAADYGZoGUTO7QdITkvZLqkh6zDn3l5E6JukvJd0naUHSf3HOPd/65gLtrTvfOKBW55+OT87plYkZPfVKfEAdDYb3HiOgAgAAYBtK0iNakvRF59zzZtYv6ZSZ/cQ591qozr2SjgWP35f0neDnzjX9hlQYlLoHJS+z1a1Bh0sSUM9OFnVuKj6gfmBPn44O94UCar8ODfUQUAEAALAlmgZR59y7kt4NtufM7Iykg5LCQfR+SU84f93eX5rZoJkdCF67M33n30mlJUnmh9Hu66SeIakn9DNaVn3evVvKMOoZm69RQK2dgzqnlycu63++vP5PshpQw/NPjw336/BQj7IEVAAAAGyiDaUlMzss6Q5Jv4rsOij
p7dDziaCsJoia2YOSHpSkG2+8cWMtTZNz0h8+Ii2+Ly1MB49ge/Yd6b1X/O3SUv1jFAbqBNXrrgyu1VCbzaf3GdHWuvMZfWhkQB8aqQ2oCyslvTE17/egTs3p3GRRL01c1j+HAmo+4+kDe3t1NDT/9Og+AioAAABaJ3EQNbM+ST+U9JBzbja6O+YlV9zV1Dn3mKTHJGlsbGzjdz1Ni5n04f/UvN7KQiisBj8XL0Wevy8VJ6WpX/vPV+frH69rl9+bGtfL2rM7PrzmNna/HnS2nnw2UUAdbxBQ1xdJCob4XkdABQAAwMYkCqJmlpMfQr/nnPv7mCoTkm4IPR+RdP7am7fN5Xv8x8AG7pmzuhQTXt8PtsPPp6WL437Zylz94+V66wTVofXe12jPbK772j872kqjgOrfA9Uf5js+WdQL/3ZJ//TS+j/vcEAdDYb3HhvuI6ACAACgriSr5pqkv5Z0xjn3jTrVnpT0eTP7vvxFimZ29PzQzZQrSLnrpV3XJ39NaSUUVqcjQ4Yv1Za9/1u/3vJM/eNlu+v3stYbQpzv9XuK0VF68ll9eGRQHx4ZrCmfXy4FiyT580/Hp5IF1NHhPt1IQAUAAOh45lzjEbJm9jFJP5P0ivzbt0jSn0u6UZKcc48GYfURSffIv33LZ51zJxsdd2xszJ082bAKrkV5NRgmHA2vdYYQL0xLS5frHy/TlWyhpnC47eonvHaY+eVSsEiSH1Crq/m+c3lxrU4+6+kDe3pr5p9W74Oa8ThfAAAA2oWZnXLOjcXuaxZENwtBdBsql/wwWje8xgwhXrwkuUr88bxcgoWaIkOICwOE1zZUDajh+6DGBdQP7u2rmX86OtyvG6/rIaACAADsQARRbJ5KZT28RoNr3fmv70uuHH88L7u+YFOjVYbD+wuDksdQz52ouFzSG6GAenbSn4caF1BHh/vWbjFDQAWwLTgnlZallaK0PCetzAfbRf9n7Pa8v/bD2nawz8v533/du4PbxgXbhcE6ZYNStmur/wQAoCGCKLaXSkVanq2/yvDa80i4razGH8+8SHhtNv91yP8C9zLpfm4kVgz1oJ5LEFDD90EloAKoyzlpdcEPgMtzQQicD0LhXGg7GiIb1K+Ukr23l5Xyff60lXyfv/ZCV1+w3SeVV/xf7C5eCh6XpaUZxdyEYF2uNxJQBxoE2VB51y5GHwFIBUEUO59z/n8CGq0yfMX812n/iz2WBV/I0aAauX1OTXjdLWU2dOtdtFg4oI6vDfOtDahd1SG+oYA6OtyvGwiowM5Tqfi3PUsUDqNBMab+SrH+dJKoTFdtUFzb7o0Jk/1XBsvo9tX0XlbKfhiNBtTFS0HZ5dqytfJLje91bl6DntZIaI2W0QsLYAMIouhMzvn/8ajbyxo3/3W68Zd3YSDZKsPhRZ2y+fQ+c4cqLpdCwXRubTXf8zPrf5dd4SG+BFRgc5RLtaHvikC4wV7HRvfejsr1bCAcRsNkTC9lJrd5f05pWF2MhNZL9UNrTVmzXtiemCHDdUJrOOB27WIaDdCBCKLARqwsNF9leG1/UN7oP0v5/tqg2mwIcfd1/m1+cM3mllb9VXwnaxdKigbUo/tq55+ODvdpZDcBFR2gtNJ8/mLDuY+R+qXF5u9ZVRMOe/1r5VX3OvYy3aJVKhX/FnBX9L6Gg2y0vNoL2+Dv37zQ0OGEva/VMr4TgR2LIApsttWlZKsMh3tmV+bqHy/X22Su627Vrjo8JOW60/u8O1w0oJ6dKupcJKAWctUe1H4dDXpPR4f7dMPuHnkEVGyFVi2ME65Td/pChHmhoNgbP/S0JkxW6/THb+d66B1rR6tLG+x9DZU16oXNdtdfsKnR4k5dA5xnwBYjiALbUWmldkhwNKjGlS3P1D9etjvBQk2RIcT5XhasCJlbWl3rNR2fLOpssP1unYB6bLhPo/v8nwRUXKEVC+OEw+SGFsbJ1QmK0XDYYE5jOFhmC1wrsHmqixg2HUZ8+cqy1YUGBw7
Wg0ja+xoupxcWaAmCKNAuyqvBMOFmCzVF7vVaT6arcS9r3BDirv6O+w/p7FoPajD/tE5APbqvT4eGetWdyyiX8dSV9ZTPespngp/Bdi7rqStSls96ygU/4163dryMR+DdLDtpYZxmvY4sKINOsdYLm7T3tboi8eXG/z6z3Q1Ca7Q8tJ9eWKAGQRToZJWy/6Vbc2/XJkOIFy/V/4L2cskWagqH28JAW4bX2aVVjU8WdW7KD6hnJ+f0zqVFLZcqWi5VtFIqa6Vc0WrZqVxp3bU261lNQM2HQ2817IaDbiT4ru2LBN/qsaKvy8WE6a7IsbKeydL+O15bGKdeTyML4wCoo1LxrxGJhhFfrg2yzXphG95Gp8EqxUyxQRsiiALYmErF/9KN3su14RDi9yVXjj+el43c67VReA32FQbb6rfK5YrTSqmilVJFy+Xy2vZqOSgvl4PwGioP1VsuVbRSru5br7dSrqy9bq08tL/6uprXrJW17vpvpppgGw27+aynbq+sXd6ydnnL6rcl9XpL6tOSerWkHltSj1tUj1tUIXh0VRbUVVlUvrygfHlBufKCsqV5ZUvzypQWlCk3WOG6tnXx921kYRwAV6O0fGU4bdT7Gi5r2AtbaDJkOBJmq3ULA1yTsG01CqLcFBHAlTxvPSwmVZ3j03CV4SCwvv8baeI5f7uyGn8884Iv3HBQbRReh/wv5236ZZzxTN35jLrzGUnbo/erUnFardSG07UAu1pWaWVR5aU5VZbmVFkuqrLk9zK65XnZalG2UpS3Mi9vdV6Zkv/IlhaUK80rV15QvrSgruUF5SuLKlQWlVOdv+uIsjPNq1tFFbTgCrqkguZdtxbUq6KGNO8Kmg/K5lXdLqiobi24gopa37+S6VY5U1CulFVennJlT/mV5r3IXaEQ7ZdL+eyC8pkldWXfj+1ZXn9dRrms1ZZnMsFrTNlM+/yCBehI2S6pf9h/bMRaL2ySYcSXpcv/Jr37ctAL22jEhkmFXcnmvl6xInF3W45Yws5AEAXQGp63PncmKef84ZFNVxme9r+Qz7/gb9dd6TNYmKLuKsNx93rd3R7DIpstjBNzOw5vpaiu5aK64lZZvZaFcfr6pPzeq1oYx+V7VfG6lCk79ZQrypYq6ipV1F8O9xZXe5bjenlDPcF1epGjr1terWhuqdSwF7nUwqHVnikUYDOhYc4WCbaZYNsiQ6MzwU+rCdL5bCb03GrrrtWx2rK1EL4FQ6uBTuMFt7ApDEi7D23staWVZHNfq2Uzb6+X1RutJPnz0zc0jJheWLQOQRTA1rHgt7iFXZKOJHuNc35IivayRoPrwrQ0+4703iv+dqnBMM7CQLJVhsP3es3mr+2zb5eFcapBsTAoDYxc3e04WrgwjsnvL85ts2+ncsVptcEw6IbDo+sMqa6+Lq58pVTR7OLqFWUrkeO3UnQecHROcXSBrfDc4egCWzW9y0nnJ0fLWZgLWJfNS337/MdGVH/hm3QRp8tvS4sv++UrxcbHLgwkm/saLaMXFoFt9lUPAE2Y+cGoq08avDH561YWkq0yXJyUpn7tP280FCrfHx9U871Bz+QmL4zTs0fafZiFcVKS8UwZL6NCbvv0ADjngrnEtb3F0d7g6hzklVJt3fBiWuH5ySuROczr85PLWlgp6fJiRauhY4UX5lopVdTCzmNlPatdeCtmoazYBbYiZfm4uhvZDpVtycJcwNWq+YXv1fTCziSb+7p4SZp5Z327YS9sPnlorQm49MK2G4IogM6Q7/EfAyPJX7O6VHuv10ZDiC+O+2UrxfiFcXZdz8I4aCkz84faZj1pG92tZaMLc0WHQSdZmGttX7B/YaGU/sJcdcJuLtKrG9eLHNeDnMuYMp6njCd5ZspmTJ6ZMp4p69kVZRnPlAlvh+qEy8NlXvVY0dcG+4Aa2bzUt9d/bIRz/ndho9Aa7pWdnZAmT/tlK3ONj901IHUPbDzI5nrohd2GCKIAUE+uIOWu90NkUs7xZYe
Otl0X5qoJqZHgG1deDb5xwbbukOnQdnG5dEVZtMd6i25cEMtMV4TTTJJQGwnDdQNydTscrKuh2JOynhccS2uB3D+OV1NW71jh94ttV/hzRMJ93OfJegT3q2bm/8K1q39jI5ek4H7pdVYkjiubeWe9vNG6Bpn8BntfQ3NhM8SlzcKfLAC0EiEU2HY8z1TYhkOrS6He41LFqRKUVSr+vYfXysr+z3BZuRJ6OKdyOfgZlFVfV3b+8eJet/ZeLv5YlUgbmh5rrY0VLZecyk4qVyoqVxS8zh+6XapUVKmo8eeptPb+y61WDe7VsBoO7vVCbTSgNwrIa+E8UhYX0mvaECmLDdYJQnr9dsV/npqyyC8wNhTcM7lr6IWdbzyMOFw+e16afM3fbtoLu8sPqU1XJI6U5Xv5P0ETBFEAAICUmZlyGX8ObO82Glq9nTjnVHFaD9bVkJ0gDMeF2urr6oXtcDhvFvg3cqz1QB4+Vm3ZcqmsstP6MULBPdHn2ebBXVJNGG40VDyurPZ18SG9WpYxKeMNBI8jobDtKdMref21bcirrEKlqJ7yrLrLc+op+T+7SrMqlGbVVZpV1+qsulZnlJ+ZVf7i28qtzCq3OiOv3i3oJFW8nCpdgyp3DahSGJTrGpTrHpQrrIdWCwKs17NbXs918np2K9OzW152e4wm2WwEUQAAAGw7ZtVQQa9SUhsO1pX6ZeGgGw7I4TqxveqhY1UibWh0rHIlCOc1IX29bC2cV6TFcrn2s13l57lSTtJQ8GjGqUfLGtC8Bq2oQStql+Y1aPMaVFEDNq/BlaJ2Ff3ng3ZRA5rXgBW1yxYbHnnOdWtGvZpRn2bVpzn1adb6NGd9mrN+zVufil6/il6/HvrMf9aRgxu8p+02QRAFAAAA2oDnmfIE98SuNljXC9vlyPD6ap3pitNUqKxSXlV2ZVaZ5RllV2aUX51RduWy8iuzyq/OqGt1VvmS3ws7UppVofSeCqVZdZdnla3OhQ0WJr6w9O8lEUQBAAAAYEfwPJMn0zaaPt6Yc/4t4kJzX/ce/J2tbtVVI4gCAAAAwHZn5i+ClO/d2O3otilvqxsAAAAAAOgsBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKlqGkTN7HEzmzKz03X2321mM2b2YvD4cuubCQAAAABoF9kEdb4r6RFJTzSo8zPn3Cda0iIAAAAAQFtr2iPqnHtW0vsptAUAAAAA0AFaNUf0o2b2kpn9yMxurVfJzB40s5NmdvLChQstemsAAAAAwE7SiiD6vKRDzrnflfRNSf9Yr6Jz7jHn3Jhzbmzv3r0teGsAAAAAwE5zzUHUOTfrnCsG209JypnZnmtuGQAAAACgLV1zEDWz/WZmwfbx4JjT13pcAAAAAEB7arpqrpn9raS7Je0xswlJX5GUkyTn3KOSHpD0OTMrSVqUdMI55zatxQAAAACAHa1pEHXOfarJ/kfk394FAAAAAICmWrVqLgAAAAAAiRBEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKkiiAIAAAAAUkUQBQAAAACkiiAKAAAAAEgVQRQAAAAAkCqCKAAAAAAgVQRRAAAAAECqCKIAAAAAgFQRRAEAAAAAqSKIAgAAAABSRRAFAAAAAKSKIAoAAAAASBVBFAAAAACQKoIoAAAAACBVBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKkiiAIAAAAAUtU0iJrZ42Y2ZWan6+w3M/srMztnZi+b2e+1vpkAAAAAgHaRpEf0u5LuabD
/XknHgseDkr5z7c0CAAAAALSrpkHUOfespPcbVLlf0hPO90tJg2Z2oFUNBAAAAAC0l1bMET0o6e3Q84mg7Apm9qCZnTSzkxcuXGjBWwMAAAAAdppWBFGLKXNxFZ1zjznnxpxzY3v37m3BWwMAAAAAdppWBNEJSTeEno9IOt+C4wIAAAAA2lArguiTkj4TrJ57p6QZ59y7LTguAAAAAKANZZtVMLO/lXS3pD1mNiHpK5JykuSce1TSU5Luk3RO0oKkz25WYwEAAAAAO1/TIOqc+1ST/U7Sf21ZiwAAAAAAba0VQ3MBAAAAAEiMIAoAAAAASBVBFAAAAACQKoIoAAAAACBVBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKkiiAIAAAAAUkUQBQAAAACkiiAKAAAAAEgVQRQAAAAAkCqCKAAAAAAgVQRRAAAAAECqCKIAAAAAgFQRRAEAAAAAqSKIAgAAAABSRRAFAAAAAKSKIAoAAAAASBVBFAAAAACQKoIoAAAAACBVBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIVaIgamb3mNnrZnbOzB6O2X+3mc2Y2YvB48utbyoAAAAAoB1km1Uws4ykb0n6A0kTkp4zsyedc69Fqv7MOfeJTWgjAAAAAKCNJOkRPS7pnHPuN865FUnfl3T/5jYLAAAAANCukgTRg5LeDj2fCMqiPmpmL5nZj8zs1rgDmdmDZnbSzE5euHDhKpoLAAAAANjpkgRRiylzkefPSzrknPtdSd+U9I9xB3LOPeacG3POje3du3djLQUAAAAAtIUkQXRC0g2h5yOSzocrOOdmnXPFYPspSTkz29OyVgIAAAAA2kaSIPqcpGNmdsTM8pJOSHoyXMHM9puZBdvHg+NOt7qxAAAAAICdr+mquc65kpl9XtKPJWUkPe6ce9XM/iTY/6ikByR9zsxKkhYlnXDORYfvAgAAAAAg26q8ODY25k6ePLkl7w0AAAAA2Fxmdso5Nxa3LyrKEKkAAAYUSURBVMnQXAAAAAAAWoYgCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKkiiAIAAAAAUkUQBQAAAACkiiAKAAAAAEgVQRQAAAAAkCqCKAAAAAAgVQRRAAAAAECqCKIAAAAAgFQRRAEAAAAAqSKIAgAAAABSRRAFAAAAAKSKIAoAAAAASBVBFAAAAACQKoIoAAAAACBVBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFJFEAUAAAAApIogCgAAAABIFUEUAAAAAJAqgigAAAAAIFUEUQAAAABAqgiiAAAAAIBUEUQBAAAAAKkiiAIAAAAAUkUQBQAAAACkiiAKAAAAAEgVQRQAAAAAkKpEQdTM7jGz183snJk9HLPfzOyvgv0vm9nvtb6pAAAAAIB20DSImllG0rck3SvpFkmfMrNbItXulXQseDwo6TstbicAAAAAoE0k6RE9Lumcc+43zrkVSd+XdH+kzv2SnnC+X0oaNLMDLW4rAAAAAKANZBPUOSjp7dDzCUm/n6DOQUnvhiuZ2YPye0wlqWhmr2+otenbI+niVjcC2xLnBurh3EAjnB+oh3MD9XBuoJ6dcG4cqrcjSRC1mDJ3FXXknHtM0mMJ3nNbMLOTzrmxrW4Hth/ODdTDuYFGOD9QD+cG6uHcQD07/dxIMjR3QtINoecjks5fRR0AAAAAABIF0eckHTOzI2aWl3RC0pOROk9K+kyweu6dkmacc+9GDwQAAAAAQNOhuc65kpl9XtKPJWUkPe6ce9XM/iTY/6ikpyTdJ+mcpAVJn928JqdqxwwjRuo4N1AP5wYa4fxAPZwbqIdzA/Xs6HPDnLtiKic
AAAAAAJsmydBcAAAAAABahiAKAAAAAEhVxwdRM7vHzF43s3Nm9nDMfjOzvwr2v2xmv7cV7cTWSHB+3G1mM2b2YvD48la0E+kys8fNbMrMTtfZz3WjgyU4P7hudCgzu8HMfmpmZ8zsVTP7s5g6XD86UMJzg2tHBzKzgpn9PzN7KTg3/iKmzo68biS5j2jbMrOMpG9J+gP5t6B5zsyedM69Fqp2r6RjweP3JX0n+Ik2l/D8kKSfOec+kXoDsZW+K+kRSU/U2c91o7N9V43PD4nrRqcqSfqic+55M+uXdMrMfsL/O6Bk54bEtaMTLUv6D865opnlJP3czH7knPtlqM6OvG50eo/ocUnnnHO/cc6tSPq+pPsjde6X9ITz/VLSoJkdSLuh2BJJzg90IOfcs5Leb1CF60YHS3B+oEM55951zj0fbM9JOiPpYKQa148OlPDcQAcKrgXF4GkueERXm92R141OD6IHJb0dej6hK//RJ6mD9pT07/6jwXCJH5nZrek0Ddsc1w00w3Wjw5nZYUl3SPpVZBfXjw7X4NyQuHZ0JDPLmNmLkqYk/cQ51xbXjY4emivJYsqiv2FIUgftKcnf/fOSDgXDJe6T9I/yh0Wgs3HdQCNcNzqcmfVJ+qGkh5xzs9HdMS/h+tEhmpwbXDs6lHOuLOl2MxuU9A9mdptzLrwOwY68bnR6j+iEpBtCz0cknb+KOmhPTf/unXOz1eESzrmnJOXMbE96TcQ2xXUDdXHd6GzBHK8fSvqec+7vY6pw/ehQzc4Nrh1wzl2W9LSkeyK7duR1o9OD6HOSjpnZETPLSzoh6clInSclfSZYjepOSTPOuXfTbii2RNPzw8z2m5kF28fl/5uaTr2l2G64bqAurhudK/h7/2tJZ5xz36hTjetHB0pybnDt6ExmtjfoCZWZdUv6j5J+Ham2I68bHT001zlXMrPPS/qxpIykx51zr5rZnwT7H5X0lKT7JJ2TtCDps1vVXqQr4fnxgKTPmVlJ0qKkE865bT8UAtfGzP5W0t2S9pjZhKSvyF88gOsGkpwfXDc6112S/kjSK8F8L0n6c0k3Slw/OlySc4NrR2c6IOl/BHdz8CT9wDn3z+2QV4zzFwAAAACQpk4fmgsAAAAASBlBFAAAAACQKoIoAAAAACBVBFEAAAAAQKoIogAAAACAVBFEAQAAAACpIogCAAAAAFL1/wFh/6Pr7D3qvAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " epoch: 2/3 246.61s loss: trn=1.7121 val=1.7051 perpl: 5.54 acc: trn=0.4958 val=0.4944\n", "v:1.0350 g:0.0000 emb.weight : 350 shape: (35, 10) \n", "v:0.4388 g:0.0000 rnn.rnn.weight_ih_l0 : 7500 shape: (750, 10) \n", "v:0.1397 g:0.0000 rnn.rnn.weight_hh_l0 : 187500 shape: (750, 250) \n", "v:0.2872 g:0.0000 rnn.rnn.bias_ih_l0 : 750 shape: (750,) \n", "v:0.2143 g:0.0000 rnn.rnn.bias_hh_l0 : 750 shape: (750,) \n", "v:0.2661 g:0.0000 fc.weight : 8750 shape: (35, 250) \n", "v:0.5275 g:0.0000 fc.bias : 35 shape: (35,) \n", " trn 222.9s 992:43% loss: 1.6882 acc: 0.4996" ] } ], "source": [ "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt\n", "\n", "model.rnn.mode = 1\n", "\n", "epochs =200 # число эпох\n", "for epoch in range(epochs): # эпоха - проход по всем примерам\n", " #optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n", " optimizer = torch.optim.SGD(model.parameters(), lr=1, momentum=0.8)\n", " \n", " beg = tm()\n", " \n", " idx = torch.randperm( len(X_trn) ) # перемешанный список индексов\n", " X_trn = X_trn[idx]\n", " Y_trn = Y_trn[idx]\n", " \n", " L_trn, A_trn = fit(model, X_trn, Y_trn, batch_size=256, NUM = NUM, train=True ) \n", " L_val, A_val = fit(model, X_val, Y_val, batch_size=256, NUM = 1, train=False) \n", " losses.append([L_trn, L_val])\n", " \n", " if epoch % 1 == 0 or it == epochs-1: \n", " clear_output(wait=True)\n", "\n", " plt.figure(figsize=(16,4)); plt.plot(losses); plt.ylim(0, 2.5); plt.legend(['trn', 'val']); plt.show()\n", " \n", " if model.rnn.mode != 1:\n", " fig, ax1 = plt.subplots()\n", " fig.set_size_inches(16,4)\n", " ax2 = ax1.twinx() \n", " ax1.set_ylabel(r'$g^{(h)}$', color='b', fontsize=18); ax2.set_ylabel('h', color='g', fontsize=18)\n", " ax2.plot(model.rnn.V_h.numpy()[1:], marker='.', linestyle='--', color=\"green\"); ax2.set_ylim(0, 
None)\n", " ax1.plot(model.rnn.G_h.numpy()[1:], marker='o', linewidth = 3, color=\"blue\"); ax1.set_ylim(0, None); \n", " ax1.plot(model.rnn.G_y.numpy(), marker='.', color=\"blue\"); ax1.set_ylim(0, None); \n", " plt.show(); \n", " \n", " print('\\r', f'epoch: {epoch:d}/{len(losses)-1} {tm()-beg:.2f}s loss: trn={L_trn:.4f} val={L_val:.4f} ', \n", " f'perpl: {np.exp(L_trn):5.2f} acc: trn={A_trn:.4f} val={A_val:.4f}' ) \n", " #f'acc: {A:.4f} {(oks*tot).sum():.4f} {oks.mean():.4f} {len(oks[oks > 1/len(oks)])/len(oks):.2f}' ) \n", " \n", " params_stat(model)\n", " \n", " #for i,c in enumerate(CHARS):\n", " # print(f\"{c}: {oks[i]:.3f}; {tot[i]:.3f}\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Результаты\n", "```\n", " mode=1 mode=2 mode=3\n", "26s -> 29.16 25.5 _> 28.82 29.2 -> 33.24 53.3 -> 57.6\n", "24.7s -> 27.86 25, 28.7 28.5 -> 32.51 \n", "25. -> 28.28 28.6 -> 32.31\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "LENGTH, STEP = 100, 1\n", "E_DIM, H_DIM, NUM_LAYERS, DROP = 10, 100, 1, 0.2\n", "epoch: 19/19 63.26s loss: trn=1.5113 val=1.3357 (last: 1.5091) perpl: 4.53 acc: 0.5208 0.5208 0.5094 1.00\n", "\n", "LENGTH, STEP = 100, 1\n", "E_DIM, H_DIM, NUM_LAYERS, DROP = 10, 100, 1, 0 GRU\n", "epoch: 19/19 55.01s loss: trn=0.9968 val=0.9993 perpl: 2.71 acc: trn=0.7358 val=0.7358\n", "epoch: 52/125 55.95s loss: trn=0.1353 val=0.1353 perpl: 1.14 acc: trn=0.9743 val=0.9743\n", "\n", "LENGTH, STEP = 100, 1 \n", "E_DIM, H_DIM, NUM_LAYERS, DROP = 10, 100, 1, 0 LSTM\n", "epoch: 21/21 56.09s loss: trn=1.2782 val=1.2694 perpl: 3.59 acc: trn=0.6559 val=0.6559\n", "epoch: 40/62 57.16s loss: trn=0.4476 val=0.4318 perpl: 1.56 acc: trn=0.8810 val=0.8810\n", "epoch: 63/126 57.46s loss: trn=0.1022 val=0.0880 perpl: 1.11 acc: trn=0.9828 val=0.9828\n", "\n", "LENGTH, STEP = 100, 1 \n", "E_DIM, H_DIM, NUM_LAYERS, DROP = 10, 100, 1, 0 GRU NUM = 90\n", "epoch: 25/25 35.94s loss: trn=0.6280 val=0.5689 perpl: 1.87 acc: trn=0.8474 
val=0.8474\n", "epoch: 32/58 37.74s loss: trn=0.2546 val=0.1764 perpl: 1.29 acc: trn=0.9609 val=0.9609\n", "epoch: 70/129 37.24s loss: trn=0.0953 val=0.0427 perpl: 1.10 acc: trn=0.9934 val=0.9934\n", "epoch: 0/156 37.74s loss: trn=0.0746 val=0.0319 perpl: 1.08 acc: trn=0.9947 val=0.9956\n", "\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Генерация текста" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "st = \"три девицы под окном пряли \"\n", "#st = \"кабы я была\"\n", "#st = \"говорит одна\"\n", "#st = \"то на весь крещ\"\n", "#st = \"уложили спать\"\n", "\n", "model.rnn.mode = 1\n", "def random_char(probs): \n", " return np.random.choice(len(probs), 1, p=probs)[0]\n", "\n", "x = torch.tensor( [ charID[ st[0] ] ], dtype=torch.long).view(1,1)\n", "h = torch.zeros (NUM_LAYERS, 1, H_DIM) \n", "for i in range(1000):\n", " y, h = model(x, h0 = h.detach(), NUM=1) # (B,C,1), (LAYERS,B,H) \n", " probs = nn.Softmax(dim=1)(y.detach()[:,:,-1])\n", " probs = probs.to(cpu).numpy().reshape(-1) \n", " idx = charID[ st[i] ] if i < len(st) else random_char(probs) \n", " x[0,0] = torch.tensor(idx, dtype=torch.long, device=gpu) \n", " print(CHARS[idx], end=\"\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Cохранение модели" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import datetime\n", " \n", "state = {'info': \"RNN генератор текста по буквам\", \n", " 'date': datetime.datetime.now(), \n", " 'parms': \"\", \n", " 'model' : model.state_dict(), \n", " 'optimizer': optimizer.state_dict()} \n", " \n", "torch.save(state, 'rnn_char_100_10_100_1_loss_0_xx.pt') " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "state = torch.load('rnn_char_100_10_100_1_loss_0_xx.pt') # загружаем файл\n", " \n", "model.load_state_dict(state['model']) # получаем параметры модели" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tests" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tm1 = tm()\n", "\n", "y = torch.empty(100,256,256) \n", "for i in range(100):\n", " x = torch.randn(256,256, \n", " requires_grad=True)\n", " y[i] = x\n", " \n", "y.retain_grad() \n", "z = y.sum()\n", "z.backward()\n", " \n", "print(tm()-tm1 )\n", "print(y.grad_fn, y.grad.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tm1 = tm()\n", "\n", "y = [] \n", "for i in range(100):\n", " x = torch.randn(1,256,256, requires_grad=True)\n", " y.append( x ) \n", "y = torch.cat(y, dim=0)\n", "\n", "#x.retain_grad() \n", "z = y.sum()\n", "z.backward()\n", " \n", "print(tm()-tm1)\n", "print(y.grad_fn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchviz\n", "from torch import tensor, empty, ones, zeros\n", "\n", "x, w, b = ones(5), ones(5, requires_grad=True), tensor(0., requires_grad=True)\n", "z = x.dot(w) + b\n", "\n", "torchviz.make_dot(z, params = {'x': x, 'w': w, 'b': b} )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = torch.tensor([1,0,2,0])\n", "torch.zeros(len(idx),3).scatter_(1, idx.unsqueeze(1), 1.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "L, C = 3, 4\n", "\n", "inp = torch.tensor([ [1,0,1], [0,2,3] ], dtype=torch.long)\n", "one_hot = torch.zeros(len(inp), L, C).scatter_(2, inp.unsqueeze(2), 1.)\n", "print(one_hot)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "toc": { "base_numbering": 1, 
"nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "468px", "left": "1115.05px", "top": "66px", "width": "164.922px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }