Hi, I get an error at the line referenced below when I use multiple points in an image to extract segments.
The input points and labels I used:
points = tensor([[ 711, 455], [1578, 611], [2019, 640], [ 412, 739], [1820, 810]], device='cuda:0')
label = tensor([1, 1, 1, 1, 1], device='cuda:0')
And the prediction call:
masks, _, _ = predictor.predict_torch( point_coords=points, point_labels=label, multimask_output=False, )
The line that raises the error: segment-anything/segment_anything/modeling/prompt_encoder.py, line 84 (commit 6fdee8f)
I think the problem is that predict_torch(...) expects its inputs to have a batch dimension.
One way to fix this is to add the extra dimension that the function expects, using the unsqueeze(...) method on the tensors:
import torch
x = 0  # or 1, if the 5 points/labels are meant to be separate prompts
points = torch.tensor([[711, 455], [1578, 611], [2019, 640], [412, 739], [1820, 810]], device='cuda:0').unsqueeze(x)
label = torch.tensor([1, 1, 1, 1, 1], device='cuda:0').unsqueeze(x)
If you use unsqueeze(1), the 5 points/labels are interpreted as 5 separate prompts, and you'll get 5 masks as output. If you use unsqueeze(0) instead, the 5 points/labels are all interpreted as belonging to a single segmentation prompt, and you'll get 1 mask as output.
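To make the difference between the two options concrete, here is a minimal sketch that only inspects tensor shapes. It runs on the CPU and does not call the predictor; the variable names are made up for illustration.

```python
import torch

# Same five click coordinates as in the question, kept on the CPU
# so the sketch runs without a GPU.
points = torch.tensor([[711, 455], [1578, 611], [2019, 640], [412, 739], [1820, 810]])
label = torch.tensor([1, 1, 1, 1, 1])

# unsqueeze(0): one prompt made of five points
# coords become (1, 5, 2), labels become (1, 5)
single_prompt_points = points.unsqueeze(0)
single_prompt_label = label.unsqueeze(0)

# unsqueeze(1): five prompts of one point each
# coords become (5, 1, 2), labels become (5, 1)
separate_prompt_points = points.unsqueeze(1)
separate_prompt_label = label.unsqueeze(1)

print(single_prompt_points.shape, single_prompt_label.shape)
print(separate_prompt_points.shape, separate_prompt_label.shape)
```

In both cases the leading dimension is the number of prompts, which is why it matches the number of masks returned when multimask_output=False.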
Another solution, if you only want 1 mask, is to use the predict(...) function (i.e. not the _torch variant). It may be simpler, although it does some extra pre-processing that you may not need if you have already applied it to your point coordinates. It would look something like:
import numpy as np  # needed to build the np.array inputs

points = [[711, 455], [1578, 611], [2019, 640], [412, 739], [1820, 810]]
label = [1, 1, 1, 1, 1]
masks, _, _ = predictor.predict(
    point_coords=np.array(points),
    point_labels=np.array(label),
    multimask_output=False,
)