
patches.Scatter() and patches.Plot() should be added. #29318


Closed
hyperkai opened this issue Dec 15, 2024 · 3 comments
Labels
Community support Users in need of help.

Comments

@hyperkai

hyperkai commented Dec 15, 2024

Problem

Using only patches.Rectangle() on the Axes created by plt.subplots() works properly, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)
    rect = Rectangle(xy=(x, y), width=w, height=h, linewidth=3, edgecolor='r', facecolor='none', clip_on=True) # Here
    axis.add_patch(p=rect)                                                                                     # Here

[Screenshot of the resulting figure]

Using only axes.Axes.scatter() on the Axes created by plt.subplots() also works properly, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)

    for px, py in lm.split(2):                # Here
        axis.scatter(x=px, y=py, c='#1f77b4') # Here

[Screenshot of the resulting figure]

Using only axes.Axes.plot() on the Axes created by plt.subplots() also works properly, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)

    px = []                    # Here
    py = []                    # Here
    for j, v in enumerate(lm): # Here
        if j%2 == 0:           # Here
            px.append(v)       # Here
        else:                  # Here
            py.append(v)       # Here
    axis.plot(px, py)          # Here

[Screenshot of the resulting figure]

But using patches.Rectangle() together with axes.Axes.scatter() on the same Axes doesn't work properly, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)
    rect = Rectangle(xy=(x, y), width=w, height=h, linewidth=3, edgecolor='r', facecolor='none', clip_on=True) # Here
    axis.add_patch(p=rect)                                                                                     # Here

    for px, py in lm.split(2):                # Here
        axis.scatter(x=px, y=py, c='#1f77b4') # Here

[Screenshot of the resulting figure]

Using patches.Rectangle() together with axes.Axes.plot() on the same Axes doesn't work properly either, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)
    rect = Rectangle(xy=(x, y), width=w, height=h, linewidth=3, edgecolor='r', facecolor='none', clip_on=True) # Here
    axis.add_patch(p=rect)                                                                                     # Here

    px = []                    # Here
    py = []                    # Here
    for j, v in enumerate(lm): # Here
        if j%2 == 0:           # Here
            px.append(v)       # Here
        else:                  # Here
            py.append(v)       # Here
    axis.plot(px, py)          # Here

[Screenshot of the resulting figure]

So instead, I used patches.Rectangle() together with patches.Circle() on the same Axes, and then they work properly, as shown below:

from torchvision.datasets import CelebA

my_data = CelebA(
    root="data",
    split="all",
    target_type=["bbox", "landmarks"]
)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for (im, ((x, y, w, h), lm)), axis in zip(my_data, axes.ravel()):
    axis.imshow(X=im)
    rect = Rectangle(xy=(x, y), width=w, height=h, linewidth=3, edgecolor='r', facecolor='none', clip_on=True) # Here
    axis.add_patch(p=rect)                                                                                     # Here

    for px, py in lm.split(2):                # Here
        axis.add_patch(p=Circle(xy=(px, py))) # Here

[Screenshot of the resulting figure]

Proposed solution

So it seems that matplotlib.patches only work properly together with other matplotlib.patches, so patches.Scatter() and patches.Plot() should be added.

@jklymak
Member

jklymak commented Dec 15, 2024

They all work the same way; it's just that scatter is calling autolim (autoscaling), and add_patch is not.

@timhoffm
Member

What is happening:

  • add_patch() is a lower-level function and updates the data limits (Axes.dataLim) to include the Patch, but doesn't touch the view limits (Axes.viewLim).
  • scatter() is a high-level function that additionally updates the view limits based on the data limits (essentially through a lazy autoscale_view()), which then picks up the rectangle limits as well.
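The two bullets above can be checked with a minimal sketch. It is an illustration under assumptions, not code from the thread: a synthetic 100×200 NumPy array stands in for a CelebA image, and the rectangle and scatter coordinates are made up so that the box sticks out past the right edge of the image. It reads `Axes.dataLim` and `Axes.viewLim` directly to show that `add_patch()` grows only the data limits, and that the view limits widen once `scatter()` has run.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

img = np.zeros((100, 200))  # assumption: stand-in for a CelebA image (100 rows x 200 columns)

fig, ax = plt.subplots()
ax.imshow(img)  # sets both the data limits and the view limits to the image extent
# Hypothetical bounding box that sticks out past the right edge of the image.
ax.add_patch(Rectangle((150, 20), 100, 40, fill=False, edgecolor="r"))

data_x = tuple(ax.dataLim.intervalx)         # grown to include the rectangle
view_x_before = tuple(ax.viewLim.intervalx)  # still the image extent
ax.scatter([10, 20], [30, 40])               # requests a lazy autoscale_view()
view_x_after = tuple(ax.viewLim.intervalx)   # recomputed from dataLim, so it widens

print("dataLim x:       ", data_x)
print("viewLim x before:", view_x_before)
print("viewLim x after: ", view_x_after)
```

The scatter points themselves lie inside the image; the view still widens because the autoscale triggered by `scatter()` uses the whole `dataLim`, which already contains the rectangle.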

If you want the limits to match the image exactly, no matter what you plot afterwards, call ax.autoscale(False) right after ax.imshow().
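A sketch of that fix, again assuming a synthetic NumPy image and hypothetical coordinates instead of CelebA data: with autoscaling switched off right after `imshow()`, adding the rectangle and the scatter points leaves the view limits at the image extent, and whatever lies outside the image is clipped.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

img = np.zeros((100, 200))  # assumption: stand-in for a CelebA image

fig, ax = plt.subplots()
ax.imshow(img)
ax.autoscale(False)  # freeze the view limits at the image extent

# Same mix as in the report: a patch plus scatter points, both partly off-image.
ax.add_patch(Rectangle((150, 20), 100, 40, fill=False, edgecolor="r"))
ax.scatter([10, 230], [30, 90])

print(ax.get_xlim(), ax.get_ylim())  # unchanged: still the image extent
```

Applied to the code in the report, this means adding `axis.autoscale(False)` right after `axis.imshow(X=im)` inside the loop.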

timhoffm added the "Community support" (Users in need of help) label and removed the "New feature" label on Dec 17, 2024
@timhoffm
Member

I'm closing this, as the solution for the present case is to deactivate autoscaling.

Independent of this, there are some ideas on how to change/improve autoscaling. These are tracked separately, e.g. in


3 participants