{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n",
    "# <!-- TITLE --> [K3IMDB4] - Reload embedded vectors\n",
    "<!-- DESC --> Retrieving embedded vectors from our trained model, using Keras 3 and PyTorch\n",
    "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n",
    "\n",
    "## Objectives :\n",
    " - The objective is to retrieve and visualize our embedded vectors\n",
    " - For this, we will use our **previously saved model**.\n",
    "\n",
    "## What we're going to do :\n",
    "\n",
    " - Retrieve our saved model\n",
    " - Extract vectors and play with\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Init python stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KERAS_BACKEND'] = 'torch'\n",
    "import json,re\n",
    "import numpy as np\n",
    "import fidle\n",
    "# Init Fidle environment\n",
    "run_id, run_dir, datasets_dir = fidle.init('K3IMDB4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 - Parameters\n",
    "The words in the vocabulary are classified from the most frequent to the rarest.  \n",
    "`vocab_size` is the number of words we will remember in our vocabulary (the other words will be considered as unknown).  \n",
    "`review_len` is the review length  \n",
    "`saved_models` where our models were previously saved  \n",
    "`dictionaries_dir` is where we will go to save our dictionaries. (./data is a good choice)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size           = 5000\n",
    "review_len           = 256\n",
    "\n",
    "saved_models         = './run/K3IMDB2'\n",
    "dictionaries_dir     = './data'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override parameters (batch mode) - Just forget this cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.override('vocab_size', 'review_len', 'saved_models', 'dictionaries_dir')"
   ]
  },
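  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `vocab_size` rule described in 1.2 can be illustrated with a minimal sketch. It assumes a tiny hypothetical `toy_word_index` (not our real dictionary): each word is replaced by its index, and any word that is absent, or whose index is not below the vocabulary size, is mapped to an *unknown* index. The value `2` used for *unknown* is an assumption; check it against the dictionary saved by the training notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy dictionary: small index = frequent word\n",
    "toy_word_index = {'the': 4, 'movie': 5, 'great': 6, 'boring': 7, 'masterpiece': 8}\n",
    "toy_vocab_size = 7   # hypothetical: only indices 0 to 6 are kept\n",
    "UNKNOWN        = 2   # assumption: index reserved for unknown words\n",
    "\n",
    "def encode(text, word_index, vocab_size, unknown=UNKNOWN):\n",
    "    # Map each word to its index, or to `unknown` if it is absent or too rare\n",
    "    indices = []\n",
    "    for w in text.lower().split():\n",
    "        i = word_index.get(w, unknown)\n",
    "        indices.append(i if i < vocab_size else unknown)\n",
    "    return indices\n",
    "\n",
    "# 'masterpiece' has index 8 >= 7, so it is encoded as unknown\n",
    "print(encode('The great movie masterpiece', toy_word_index, toy_vocab_size))"
   ]
  },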
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Get the embedding vectors !"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 - Load model and dictionaries\n",
    "Note : This dictionary is generated by [02-Embedding-Keras](02-Keras-embedding.ipynb) notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(f'{saved_models}/models/best_model.keras')\n",
    "print('Model loaded.')\n",
    "\n",
    "with open(f'{dictionaries_dir}/word_index.json', 'r') as fp:\n",
    "    word_index = json.load(fp)\n",
    "    index_word = { i:w      for w,i in word_index.items() }\n",
    "    print('Dictionaries loaded. ', len(word_index), 'entries' )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 - Retrieve embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = model.layers[0].get_weights()[0]\n",
    "print('Shape of embeddings : ',embeddings.shape)"
   ]
  },
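  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why are these weights our word vectors? An `Embedding` layer holds a matrix with one row per word index (`input_dim` rows, here `vocab_size`) and `output_dim` columns. Looking up a word index simply selects the matching row, which is the same as multiplying a one-hot vector by that matrix. A minimal NumPy sketch with a small random matrix (toy sizes, not our trained weights):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy sizes (hypothetical), much smaller than our real embedding matrix\n",
    "toy_vocab, toy_dim = 10, 4\n",
    "rng = np.random.default_rng(0)\n",
    "E = rng.normal(size=(toy_vocab, toy_dim))   # plays the role of get_weights()[0]\n",
    "\n",
    "tokens = np.array([3, 7, 3])                # a tiny review, as word indices\n",
    "\n",
    "# Embedding lookup: select one row per token\n",
    "lookup = E[tokens]                          # shape (3, toy_dim)\n",
    "\n",
    "# Same result with one-hot vectors multiplied by the matrix\n",
    "one_hot = np.eye(toy_vocab)[tokens]         # shape (3, toy_vocab)\n",
    "product = one_hot @ E\n",
    "\n",
    "print(lookup.shape, np.allclose(lookup, product))"
   ]
  },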
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 - Build a nice dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_embedding = { index_word[i]:embeddings[i] for i in range(vocab_size) }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Have a look !\n",
    "#### Show embedding of a word :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_embedding['nice']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Few usefull functions to play with"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Return a l2 distance between 2 words\n",
    "#\n",
    "def l2w(w1,w2):\n",
    "    v1=word_embedding[w1]\n",
    "    v2=word_embedding[w2]\n",
    "    return np.linalg.norm(v2-v1)\n",
    "\n",
    "# Show distance between 2 words \n",
    "#\n",
    "def show_l2(w1,w2):\n",
    "    print(f'\\nL2 between [{w1}] and [{w2}] : ',l2w(w1,w2))\n",
    "\n",
    "# Displays the 15 closest words to a given word\n",
    "#\n",
    "def neighbors(w1):\n",
    "    v1=word_embedding[w1]\n",
    "    dd={}\n",
    "    for i in range(4, 1000):\n",
    "        w2=index_word[i]\n",
    "        dd[w2]=l2w(w1,w2)\n",
    "    dd= {k: v for k, v in sorted(dd.items(), key=lambda item: item[1])}\n",
    "    print(f'\\nNeighbors of [{w1}] : ', list(dd.keys())[1:15])\n",
    "    "
   ]
  },
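  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helpers above rank words by L2 distance. Cosine similarity, which compares only the direction of two vectors, is a common alternative for word embeddings; it is swapped in here for comparison and is not what the functions above use. A minimal vectorized sketch; `cosine_neighbors` and the toy vectors are hypothetical. Unlike `neighbors()`, it scans every entry of the dictionary it is given, including indices 0 to 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch (hypothetical helper): cosine similarity instead of the L2 distance used above\n",
    "def cosine_neighbors(word, emb, n=5):\n",
    "    # emb: dict word -> 1D vector (for example our word_embedding)\n",
    "    words = [w for w in emb if w != word]\n",
    "    M = np.stack([emb[w] for w in words])        # shape (len(words), dim)\n",
    "    v = emb[word]\n",
    "    # cos = (m . v) / (|m| |v|), small epsilon avoids a division by zero\n",
    "    sims = M @ v / (np.linalg.norm(M, axis=1) * np.linalg.norm(v) + 1e-12)\n",
    "    best = np.argsort(-sims)[:n]                  # highest similarity first\n",
    "    return [(words[i], float(sims[i])) for i in best]\n",
    "\n",
    "# Toy 2D vectors (hypothetical), just to check the function\n",
    "toy = {'good':  np.array([1.0, 0.1]), 'nice':  np.array([0.9, 0.2]),\n",
    "       'bad':   np.array([-1.0, 0.0]), 'awful': np.array([-0.8, -0.1])}\n",
    "print(cosine_neighbors('good', toy, n=2))\n",
    "\n",
    "# With our data: cosine_neighbors('great', word_embedding, n=15)"
   ]
  },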
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_l2('nice', 'pleasant')\n",
    "show_l2('nice', 'horrible')\n",
    "\n",
    "neighbors('horrible')\n",
    "neighbors('great')\n"
   ]
  },
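  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The objectives also mention visualizing the vectors. Each vector usually has far more than 2 dimensions, so one simple option, not used in the original notebook, is to project them onto their first 2 principal components (PCA). A minimal NumPy sketch based on an SVD; `pca_2d` is a hypothetical helper, and random toy vectors are used so the cell runs on its own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def pca_2d(vectors):\n",
    "    # vectors: array of shape (n_words, dim)\n",
    "    X = vectors - vectors.mean(axis=0)          # center each dimension\n",
    "    # Rows of Vt are the principal directions, sorted by decreasing variance\n",
    "    _, _, Vt = np.linalg.svd(X, full_matrices=False)\n",
    "    return X @ Vt[:2].T                          # coordinates on the first 2 components\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "toy_vectors = rng.normal(size=(6, 16))          # hypothetical: 6 words, 16 dimensions\n",
    "coords = pca_2d(toy_vectors)\n",
    "print(coords.shape)\n",
    "\n",
    "# With our data (assumption): coords = pca_2d(np.stack([word_embedding[w] for w in words]))\n",
    "# for a list of words, then plot with matplotlib.pyplot.scatter and label with plt.annotate"
   ]
  },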
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fidle.end()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.2 ('fidle-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}