Monday, March 29, 2021

Error with matplotlib scatter plot due to color palette

I used this method to create a scatter plot for another model on the MNIST dataset, and it works fine for that model, but I cannot figure out what I did wrong with this one.

The method is

    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

    def scatter(x, labels, subtitle=None):
        # Create a scatter plot of all the
        # embeddings of the model.
        # We choose a color palette with seaborn.
        palette = np.array(sns.color_palette("hls", 10))
        # We create a scatter plot.
        f = plt.figure(figsize=(8, 8))
        ax = plt.subplot(aspect='equal')
        sc = ax.scatter(x[:, 0], x[:, 1], lw=0, alpha=0.5, s=40,
                        c=palette[labels.astype(int)])
        plt.xlim(-25, 25)
        plt.ylim(-25, 25)
        ax.axis('off')
        ax.axis('tight')
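The color lookup `palette[labels.astype(int)]` relies on NumPy integer-array indexing: each integer label selects one row of the `(10, 3)` RGB palette, so a 1-D label array of length N should produce an `(N, 3)` color array. A minimal sketch, assuming the labels are a 1-D array of integer class ids; a stand-in array replaces `sns.color_palette` so the snippet needs only NumPy (its values are hypothetical, only the shape matters):

```python
import numpy as np

# Stand-in for np.array(sns.color_palette("hls", 10)):
# 10 colors x 3 RGB channels (hypothetical values, same shape).
palette = np.linspace(0.0, 1.0, 30).reshape(10, 3)

# Assumed label format: a 1-D array with one integer class id per point.
labels = np.array([3, 0, 7, 7, 1])

# Integer-array indexing picks palette row labels[i] for point i.
colors = palette[labels.astype(int)]
print(colors.shape)  # (5, 3): one RGB row per point
```

With labels in this format, `c` has exactly one color per point, which is what `ax.scatter` expects for an RGB array.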

I use the following to create the data for the plot, using the MNIST dataset from Keras:

    from sklearn.manifold import TSNE

    # Using the newly trained model, compute the embeddings
    # for a number of images.
    sample_size = 5000
    X_train_trm = model.predict(X_train[:sample_size].reshape(-1, 28, 28, 1))
    X_test_trm = model.predict(X_test[:sample_size].reshape(-1, 28, 28, 1))
    # Use t-SNE dimensionality reduction to visualise the resulting embeddings.
    tsne = TSNE()
    train_tsne_embeds = tsne.fit_transform(X_train_trm)
    scatter(train_tsne_embeds, y_train[:sample_size])

This then gives the following error, which I do not understand, since when I check the size of the palette and of c, I expect c to have 5000 elements, not 150000. The error is:

ValueError: 'c' argument has 150000 elements, which is inconsistent with 'x' and 'y' with size 5000.  
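One shape combination that gives exactly 150000 elements: if `y_train` had been one-hot encoded before this point (for example with `keras.utils.to_categorical`), then `y_train[:sample_size]` has shape `(5000, 10)`, and indexing the `(10, 3)` palette with it yields a `(5000, 10, 3)` array, i.e. 5000 × 10 × 3 = 150000 values. This is an assumption, since the snippet does not show how `y_train` was prepared. A NumPy-only sketch reproducing those shapes, with hypothetical random labels and `np.argmax` used to recover integer class ids:

```python
import numpy as np

sample_size = 5000
num_classes = 10

# Stand-in with the same shape as np.array(sns.color_palette("hls", 10)).
palette = np.zeros((num_classes, 3))

# Hypothetical integer labels and their one-hot version
# (assumed to be what y_train actually holds).
rng = np.random.default_rng(0)
int_labels = rng.integers(0, num_classes, size=sample_size)
one_hot = np.eye(num_classes)[int_labels]           # shape (5000, 10)

# Indexing with 2-D one-hot labels: every 0/1 entry selects a palette row,
# so the result gains an extra axis of length 10.
bad_c = palette[one_hot.astype(int)]
print(bad_c.shape, bad_c.size)                       # (5000, 10, 3) 150000

# Collapsing each one-hot row back to its class id gives one color per point.
good_c = palette[np.argmax(one_hot, axis=1)]
print(good_c.shape)                                  # (5000, 3)
```

If `labels.ndim == 2` inside `scatter`, this would be the situation to check for; converting with `np.argmax(labels, axis=1)` before the palette lookup restores the 1-D integer format the method was written for.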
https://stackoverflow.com/questions/66863527/error-with-matplotlib-scatter-plot-due-to-color-palette March 30, 2021 at 09:06AM
