View Source Segment Anything Model (SAM) using Ortex
Mix.install([
{:ortex, "~> 0.1.9"},
{:image, "~> 0.47"},
{:nx_image, "~> 0.1.2"},
{:exla, "~> 0.7.2"},
{:kino, "~> 0.12.3"},
{:req, "~> 0.4"}
])
Recap
This is an implementation by @herisson of the Segment Anything Model (SAM) from Facebook Research, using the Ortex library to run ONNX models.
It is approximately a port of this Jupyter notebook example to Livebook: https://colab.research.google.com/drive/1wmjHHcrZ_s8iFuVFh9iHo6GbUS_xH5xq
It uses the onnx "mobile" models found here.
# Select EXLA as the global default Nx backend for the tensor operations in
# this notebook.
Nx.global_default_backend(EXLA.Backend)
# Read the default backend back; as the last expression of the cell its value
# is what gets displayed, so the reader can confirm the switch took effect.
Nx.default_backend()
Loading the models
# Download the two ONNX files (image encoder and mask decoder) and load them
# with Ortex. `model` is the encoder and `decoder` is the decoder.
#
# The files go into the system temporary directory rather than a hard-coded
# "/tmp", which does not exist on every platform.
model_dir = System.tmp_dir!()

# Fetches `url`, writes the response body to `path` and returns `path`.
# Matching on `status: 200` makes a failed download raise here instead of
# silently writing an error page to disk and failing later in Ortex.load/1.
download_model = fn url, path ->
  %{status: 200, body: body} = Req.get!(url)
  File.write!(path, body)
  path
end

encoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vision_encoder_quantized.onnx"

encoder_path = Path.join(model_dir, "vision_encoder_quantized.onnx")
download_model.(encoder_url, encoder_path)
model = Ortex.load(encoder_path)

decoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vit_b_decoder.onnx"

decoder_path = Path.join(model_dir, "vit_b_decoder.onnx")
download_model.(decoder_url, decoder_path)
decoder = Ortex.load(decoder_path)
Getting the image
# Upload widget for the picture to segment.
image_input = Kino.Input.image("Uploaded Image")

# The picture used by the original author: https://ibb.co/Hr8588Y
image = image_input |> Kino.Input.read() |> Image.from_kino!()

# Fit the picture into a 1024x1024 box. The demo image already has these
# dimensions, but other uploads may not.
resized = image |> Image.thumbnail!("1024x1024")

original_label = Kino.Markdown.new("**Original image**")
resized_label = Kino.Markdown.new("**Resized image**")

# Show both versions side by side, each boxed together with its caption.
[{image, original_label}, {resized, resized_label}]
|> Enum.map(fn {picture, caption} -> Kino.Layout.grid([picture, caption], boxed: true) end)
|> Kino.Layout.grid(columns: 2)
# Hand the resized image over to Nx and cast it to 32-bit floats. Per the
# original author's note this is a zero-copy operation: Nx is essentially
# given a pointer to the image's memory area.
tensor = Nx.as_type(Image.to_nx!(resized), :f32)

# Per-channel mean and standard deviation, copied from transformer.js.
mean = Nx.tensor([123.675, 116.28, 103.53])
std = Nx.tensor([58.395, 57.12, 57.375])

# Normalize, name the axes, then reorder them so that the bands axis comes
# first.
normalized_tensor =
  NxImage.normalize(tensor, mean, std)
  |> Nx.tensor(names: [:height, :width, :bands])
  |> Nx.transpose(axes: [:bands, :height, :width])
# Run the image encoder. The shape of the result tuple depends on the ONNX
# model in use; here the first element holds the embeddings and the second
# element is ignored.
encoder_input = Nx.broadcast(normalized_tensor, {1, 3, 1024, 1024})
{image_embeddings, _} = Ortex.run(model, encoder_input)
# Variant kept by the original author for a model with a different signature:
# {image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
Prompt encoding & mask generation
# Prepare the prompt inputs for the decoder.
#
# Box prompt variant: xy coordinates, in our image, of the box around the
# object we want to cut out. Labels 2 and 3 mark the box start / end points.
# input_point = Nx.tensor([[345, 272], [640, 760]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# input_label = Nx.tensor([2, 3]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)

# Single-point prompt.
# NOTE(review): presumably label 1 marks the prompt point and -1 a padding
# entry for the second row - confirm against the SAM ONNX export.
input_point = Nx.tensor([[514, 514], [0, 0]], type: :f32) |> Nx.reshape({1, 2, 2})
input_label = Nx.tensor([1, -1], type: :f32) |> Nx.reshape({1, 2})

# No previous mask is supplied: the mask input is all zeros and the
# accompanying flag is 0.
mask_input = Nx.broadcast(Nx.tensor(0, type: :f32), {1, 1, 256, 256})
has_mask = Nx.as_type(Nx.broadcast(0, 1), :f32)

original_image_dim = Nx.tensor([1024, 1024], type: :f32)

# Every input is broadcast to the exact shape passed to the decoder.
decoder_inputs = {
  Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
  Nx.broadcast(input_point, {1, 2, 2}),
  Nx.broadcast(input_label, {1, 2}),
  Nx.broadcast(mask_input, {1, 1, 256, 256}),
  Nx.broadcast(has_mask, {1}),
  Nx.broadcast(original_image_dim, {2})
}

{mask, scores, low_res} = Ortex.run(decoder, decoder_inputs)
Transform output to a black and white image
# Turn the decoder output into black-and-white mask images.
#
# `use Image.Math` enables the operator overloads that allow an image to be
# compared with a number (`image > 0`) below.
use Image.Math

# Take the entry at index 2 of the first batch element and wrap it as a
# single-band 1024x1024 image.
mask =
  mask[0][2]
  |> Nx.reshape({1024, 1024, 1})
  |> Image.from_nx!()

# Same for the low resolution output, which is 256x256.
low_res =
  low_res[0][2]
  |> Nx.reshape({256, 256, 1})
  |> Image.from_nx!()

# Threshold the masks to produce black and white.
# This is one NIF call only. Then take only the
# first band, which we'll use as the alpha band.
mask = Image.if_then_else!(mask > 0, 255, 0)[0]

# Threshold the low resolution output itself. The previous code compared
# `mask` again, which by this point is the already-thresholded high
# resolution mask, so the "Low Res" panel merely repeated the high
# resolution result.
low_res_mask = Image.if_then_else!(low_res > 0, 255, 0)[0]

mask_label = Kino.Markdown.new("**Mask**")
low_res_label = Kino.Markdown.new("**Low Res**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([low_res_mask, low_res_label], boxed: true)
  ],
  columns: 2
)
Showing the masks
As you can see, the high-res mask is distorted / badly placed. However, the low-res mask seems to fit correctly (?).
# Attach the mask as the alpha band of the picture it was computed for.
#
# The mask is 1024x1024 because it was generated from `resized`, so it is
# applied to `resized` rather than to `image`: an upload that is not already
# 1024x1024 would otherwise be combined with a mask of different dimensions.
# For the 1024x1024 demo image the two have the same dimensions.
new_image = Image.add_alpha!(resized, mask)

mask_label = Kino.Markdown.new("**Image mask**")
new_image_label = Kino.Markdown.new("**new image**")

# Original upload, mask and masked result side by side.
Kino.Layout.grid(
  [
    Kino.Layout.grid([image, original_label], boxed: true),
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([new_image, new_image_label], boxed: true)
  ],
  columns: 3
)