#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import numpy as np

import tripleblind as tb


# NOTE(review): judging by the helper's name, this presumably makes the script's
# own directory the working directory, so that the relative output paths used
# below (model.pth, model.onnx) land next to this script — confirm in tb.util.
tb.util.set_script_dir_current()

# Stopped supporting Keras in this example due to compatibility issues with the
# latest version of Keras.
#
# try:
#     from keras import layers, models
#     from tensorflow import keras
#
#     print()
#     print("Creating keras model...")
#     print()
#     keras_model = models.Sequential(
#         [
#             layers.Conv2D(8, (5, 5), activation="relu", input_shape=(32, 32, 3)),
#             layers.MaxPooling2D((2, 2)),
#             layers.Conv2D(16, (5, 5), activation="relu"),
#             layers.MaxPooling2D((2, 2)),
#             layers.Conv2D(64, (5, 5), activation="relu"),
#             layers.Flatten(),
#             layers.Dense(64, activation="relu"),
#             layers.Dense(10),
#         ]
#     )
#     keras_model.build((1, 32, 32, 3))
#     keras_model.save("model.h5", save_format="h5")
#     print("Saved model.h5")
# except ImportError:
#     print("keras not found, skipping model creation...")

# Build a small example CNN with PyTorch and save it to model.pth.
#
# pytorch_model stays None when PyTorch is unavailable, so later sections can
# detect that no model was created instead of tripping over an undefined name.
pytorch_model = None
try:
    import torch
    import torch.nn.functional as F  # noqa: F401  (not used in this script)
    from torch import nn
except ImportError:
    print("pytorch not found, skipping model creation...")
else:
    # Only the imports are guarded above. An ImportError raised while building
    # or saving the model should surface rather than be reported as
    # "pytorch not found".
    print()
    print("Creating pytorch model...")
    print()
    # Sized for (N, 3, 32, 32) input. Each 5x5 conv (no padding) shrinks each
    # side by 4 and each 2x2 pool halves it:
    # 32 -> 28 -> 14 -> 10 -> 5, so the flattened size is 16 * 5 * 5.
    pytorch_model = nn.Sequential(
        nn.Conv2d(3, 8, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(8, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 64),
        nn.ReLU(),
        nn.Linear(64, 10),  # 10 outputs, no final activation
    )
    # torch.save on a module pickles the whole module object, not just its
    # state_dict.
    torch.save(pytorch_model, "model.pth")
    print("Saved model.pth")

# Export the PyTorch model built above to model.onnx. This only happens when
# the onnx package is installed and the PyTorch section actually produced a
# model.
try:
    # Not used directly; imported to confirm the onnx package is installed
    # before attempting an export.
    import onnx  # noqa: F401
except ImportError:
    print("Unable to load onnx (install with 'pip install onnx')")
    print("Skipping model creation.")
else:
    print()
    print("Creating onnx model...")
    print()
    # pytorch_model (and torch) only exist if the PyTorch section succeeded.
    # At module scope an undefined global raises NameError, never
    # UnboundLocalError, so check explicitly instead of relying on a handler.
    if globals().get("pytorch_model") is None:
        print("pytorch model must be created first for onnx model creation.")
    else:
        # Random float32 example input of shape (1, 3, 32, 32), used by the
        # exporter to run the model.
        X = np.random.random((1, 3, 32, 32)).astype(np.float32)
        # Hand the exporter the output path instead of opening model.onnx
        # here first. Opening it with "wb" truncated/created the file even
        # when the export itself then failed. The example inputs are passed
        # as a tuple, the form torch.onnx.export documents for `args`.
        torch.onnx.export(
            pytorch_model,
            args=(torch.from_numpy(X),),
            f="model.onnx",
            opset_version=15,
        )
        print("Saved model.onnx")
