前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow 2.0 - Hub 模型复用

TensorFlow 2.0 - Hub 模型复用

作者头像
Michael阿明
发布2021-09-06 16:01:46
9060
发布2021-09-06 16:01:46
举报

?

学习于:简单粗暴 TensorFlow 2

1. tfhub

网址:

https://hub.tensorflow.google.cn/

https://tfhub.dev/

  • 可以搜索,下载模型
  • 安装包 pip install tensorflow-hub
代码语言:javascript
复制
import tensorflow_hub as hub

hub_url = 'https://hub.tensorflow.google.cn/google/magenta/arbitrary-image-stylization-v1-256/2'
hub_model = hub.load(hub_url) # 加载模型
outputs = hub_model(inputs) # 调用模型

2. 例子:神经风格转换

Ng 课也讲过这个例子

代码语言:javascript
复制
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


# 归一化,resize
def load_image_local(img_path, img_size=(256, 256)):
    # png 4 通道转 jpg 3通道
    if 'png' in img_path:
        img = Image.open(img_path)
        img = img.convert('RGB')
        img.save("temp.jpg")
        img = plt.imread("temp.jpg").astype(np.float32)[np.newaxis, :, :, :]
    else:
        # 添加一个 batch_size 轴
        img = plt.imread(img_path).astype(np.float32)[np.newaxis, :, :, :]
    if img.max() > 1.0:
        img = img / 255.
    img = tf.image.resize(img, img_size, preserve_aspect_ratio=True)
    return img


# 绘制图片
def show_image(img, title, save=False, fig_dpi=300):
    plt.imshow(img, aspect='equal')
    plt.axis('off')
    plt.show()
    if save:
        plt.imsave(title + '.jpg', img.numpy())


# 图片路径
content_image_path = "pic1.jpg"
style_image_path = "pic2.jpg"

# 处理图片
content_image = load_image_local(content_image_path)
style_image = load_image_local(style_image_path)

# 展示图片
show_image(content_image[0], "Content Image")
show_image(style_image[0], "Style Image")

# 加载模型
hub_url = 'https://hub.tensorflow.google.cn/google/magenta/arbitrary-image-stylization-v1-256/2'
hub_model = hub.load(hub_url)

# 调用模型
outputs = hub_model(tf.constant(content_image), tf.constant(style_image))
stylized_image = outputs[0]  # 取出第一个样本预测值 [ :, :, 3]

# 展示预测图片
show_image(stylized_image[0], "Stylized Image", True)

内容图片:

风格图片:

转换后的图片:

3. retrain 例子

https://hub.tensorflow.google.cn/google/imagenet/inception_v3/feature_vector/4

  • hub.KerasLayer(url) 封装一个layer到模型当中,可以设置是否 finetune
代码语言:javascript
复制
num_classes = 10
model = tf.keras.Sequential([
    hub.KerasLayer("https://hub.tensorflow.google.cn/google/imagenet/inception_v3/feature_vector/4",
                   trainable=False),  # 可以设为True,微调
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
model.build([None, 299, 299, 3])  # Batch input shape
model.summary()

模型结构

代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 2048)              21802784  
_________________________________________________________________
dense (Dense)                (None, 10)                20490     
=================================================================
Total params: 21,823,274
Trainable params: 20,490
Non-trainable params: 21,802,784
_________________________________________________________________
本文参与?腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-02-05 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客?前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • ?
  • 1. tfhub
  • 2. 例子:神经风格转换
  • 3. retrain 例子
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com