diff --git a/3. O'Reilly Generate from image.ipynb b/3. O'Reilly Generate from image.ipynb index 7d1b1da..df5810f 100644 --- a/3. O'Reilly Generate from image.ipynb +++ b/3. O'Reilly Generate from image.ipynb @@ -287,6 +287,7 @@ " image = cv2.imread(x)\n", " if as_float:\n", " image = image.astype(np.float32)\n", + " image = image / image.max()\n", "\n", " if len(image.shape) == 2:\n", " image = np.tile(image[:,:,None], 3)\n", @@ -351,7 +352,7 @@ }, "outputs": [], "source": [ - "def test(sess,image,generated_words,ixtoword,test_image_path=0): # Naive greedy search\n", + "def test(sess,image,generated_words,ixtoword,test_image_path=image_path): # Naive greedy search\n", "\n", " \n", "\n", @@ -370,9 +371,11 @@ " generated_word_index= sess.run(generated_words, feed_dict={image:fc7})\n", " generated_word_index = np.hstack(generated_word_index)\n", " generated_words = [ixtoword[x] for x in generated_word_index]\n", - " punctuation = np.argmax(np.array(generated_words) == '.')+1\n", - "\n", - " generated_words = generated_words[:punctuation]\n", + " \n", + " # Select words until the first period, or all words if there is no period\n", + " punctuation = np.argmax(np.array(generated_words) == '.')\n", + " if punctuation == 0: punctuation = len(generated_words)\n", + " generated_words = generated_words[:punctuation+1]\n", " generated_sentence = ' '.join(generated_words)\n", " print(generated_sentence)" ]