{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n", "\n", "# <!-- TITLE --> [K3IMDB4] - Reload embedded vectors\n", "<!-- DESC --> Retrieving embedded vectors from our trained model, using Keras 3 and PyTorch\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - The objective is to retrieve and visualize our embedded vectors\n", " - For this, we will use our **previously saved model**.\n", "\n", "## What we're going to do :\n", "\n", " - Retrieve our saved model\n", " - Extract the vectors and play with them\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init Python stuff" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['KERAS_BACKEND'] = 'torch'\n", "\n", "import keras\n", "\n", "import json,re\n", "import numpy as np\n", "\n", "import fidle\n", "\n", "# Init Fidle environment\n", "run_id, run_dir, datasets_dir = fidle.init('K3IMDB4')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 - Parameters\n", "The words in the vocabulary are ranked from the most frequent to the rarest.  \n", "`vocab_size` is the number of words we keep in our vocabulary (all other words are treated as unknown).  \n", "`review_len` is the review length.  \n", "`saved_models` is where our models were previously saved.  \n", "`dictionaries_dir` is where our dictionaries were previously saved. 
(./data is a good choice)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab_size = 5000\n", "review_len = 256\n", "\n", "saved_models = './run/K3IMDB2'\n", "dictionaries_dir = './data'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override parameters (batch mode) - Just forget this cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.override('vocab_size', 'review_len', 'saved_models', 'dictionaries_dir')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Get the embedding vectors !" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 - Load model and dictionaries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = keras.models.load_model(f'{saved_models}/models/best_model.keras')\n", "print('Model loaded.')\n", "\n", "with open(f'{dictionaries_dir}/word_index.json', 'r') as fp:\n", "    word_index = json.load(fp)\n", "    index_word = { i:w for w,i in word_index.items() }\n", "    print('Dictionaries loaded. ', len(word_index), 'entries' )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 - Retrieve embeddings" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The embedding layer is assumed to be the first layer of the model\n", "embeddings = model.layers[0].get_weights()[0]\n", "print('Shape of embeddings : ',embeddings.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 - Build a nice dictionary" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "word_embedding = { index_word[i]:embeddings[i] for i in range(vocab_size) }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Have a look !\n", "#### Show embedding of a word :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "word_embedding['nice']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### A few useful functions to play with" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Return the L2 (Euclidean) distance between 2 words\n", "#\n", "def l2w(w1,w2):\n", "    v1=word_embedding[w1]\n", "    v2=word_embedding[w2]\n", "    return np.linalg.norm(v2-v1)\n", "\n", "# Show the L2 distance between 2 words\n", "#\n", "def show_l2(w1,w2):\n", "    print(f'\\nL2 between [{w1}] and [{w2}] : ',l2w(w1,w2))\n", "\n", "# Display the 15 closest words to a given word,\n", "# searching among the words with indexes 4 to 999\n", "#\n", "def neighbors(w1):\n", "    dd={}\n", "    for i in range(4, 1000):\n", "        w2=index_word[i]\n", "        # Skip the word itself, so it is never listed as its own neighbor\n", "        if w2!=w1:\n", "            dd[w2]=l2w(w1,w2)\n", "    dd= {k: v for k, v in sorted(dd.items(), key=lambda item: item[1])}\n", "    print(f'\\nNeighbors of [{w1}] : ', list(dd.keys())[:15])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_l2('nice', 'pleasant')\n", "show_l2('nice', 'horrible')\n", "\n", "neighbors('horrible')\n", "neighbors('great')\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.end()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.2 ('fidle-env')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "vscode": { "interpreter": { "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d" } } }, "nbformat": 4, "nbformat_minor": 4 }