Change elements of an array based on a condition using np.where

Let’s say we want to convert a categorical variable with multiple categories into a binary variable by labelling one category as “0” and the rest as “1”.

Or we want to change the values of an array based on a condition, as in the ReLU function, where all negative values are converted to zero and the rest stay the same.

We can do this using the np.where function.
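In its three-argument form, np.where(condition, x, y) works element by element: it takes the value from x where the condition is True and from y where it is False. Scalars such as 0 and 1 are broadcast against the condition. The small sketch below uses made-up arrays (cond, x and y are illustrative names, not from the original) to show both behaviours:

```python
import numpy as np

# illustrative inputs: a boolean condition and two candidate arrays
cond = np.array([True, False, True])
x = np.array([10, 20, 30])
y = np.array([-1, -2, -3])

# element-wise choice: x[i] where cond[i] is True, otherwise y[i]
print(np.where(cond, x, y))  # [10 -2 30]

# scalar arguments are broadcast to the shape of the condition
print(np.where(cond, 1, 0))  # [1 0 1]
```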

Let’s take an array of the letters “A” to “E”. We want the letter “C” to be labelled “0” and the rest of the letters to be labelled “1”. Here is how we do it:

import numpy as np
# create array
a = np.array(['A', 'B', 'C', 'D', 'E'])

# convert to binary labels
b = np.where(a == 'C', 0, 1)

print(f'a = {a}')
print(f'b = {b}')
Output:

a = ['A' 'B' 'C' 'D' 'E']
b = [1 1 0 1 1]
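If more than one category should map to “0”, the condition can test membership in a list instead of a single equality. The sketch below swaps in np.isin (not used in the original) to build that boolean mask; the choice of “B” and “C” as the zero-labelled categories is purely illustrative:

```python
import numpy as np

a = np.array(['A', 'B', 'C', 'D', 'E'])

# np.isin returns True wherever an element of a appears in the given list,
# so both 'B' and 'C' are labelled 0 and every other letter is labelled 1
b = np.where(np.isin(a, ['B', 'C']), 0, 1)
print(b)  # [1 0 0 1 1]
```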

Now, let’s say we have a three-dimensional array of shape (20, 4, 4) and we want to emulate the ReLU function, where every negative number in the array is set to 0 and the rest remain unchanged.

We will generate an array of random numbers from the standard normal distribution for our example and apply np.where to do the transformation.

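# sample a (20, 4, 4) array from the standard normal distribution
a = np.random.normal(size=(20, 4, 4))

# ReLU: replace negative values with 0, keep the rest unchanged
b = np.where(a < 0, 0, a)

# print the first 4x4 slice of each array to compare them
print(a[0])
print(b[0])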
Output:

[[ 1.45872533 -0.24965688 -1.11663205 -0.65852554]
 [-1.13076242 -0.49868332 -0.46350182 -0.02889719]
 [-0.99350298  0.88240974  0.87975654 -0.28836425]
 [-0.10684949 -0.88570172  1.70835701 -0.16105656]]

[[1.45872533 0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.88240974 0.87975654 0.        ]
 [0.         0.         1.70835701 0.        ]]

All negative numbers in the array have been replaced with 0, while the positive numbers are unchanged. (Your values will differ, since the array is random.)
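For this particular ReLU case, np.where is not the only option. The sketch below compares it with two alternatives that are not part of the original example: np.maximum(a, 0), an element-wise maximum against 0, and boolean-mask assignment on a copy. The seeded generator is our addition so that the comparison is reproducible:

```python
import numpy as np

# seeded generator so the example is reproducible (an assumption, not from the original)
rng = np.random.default_rng(0)
a = rng.standard_normal(size=(20, 4, 4))

relu_where = np.where(a < 0, 0, a)  # the np.where approach: returns a new array
relu_max = np.maximum(a, 0)         # element-wise maximum with 0

# boolean-mask assignment modifies the array in place, so work on a copy
relu_mask = a.copy()
relu_mask[relu_mask < 0] = 0

print(np.array_equal(relu_where, relu_max))   # True
print(np.array_equal(relu_where, relu_mask))  # True
print((relu_where >= 0).all())                # True
```

np.where leaves the input untouched, which is convenient when you still need the original values; the mask assignment avoids allocating a second array when you don't.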

The examples above used the equality (==) and less-than (<) operators, but any comparison operator can be used in the condition. For example, we can take the log2 of only those numbers in an array that are greater than or equal to a certain value (here, 5):

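a = np.random.normal(size=(20, 4, 4)) * 10

# np.log2(a) is evaluated on every element, including negative ones,
# so suppress the "invalid value" warning; those results are discarded anyway
with np.errstate(invalid='ignore', divide='ignore'):
    b = np.where(a >= 5, np.log2(a), a)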
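Conditions can also be combined with & (and) or | (or), as long as each comparison is wrapped in parentheses. And because np.where evaluates np.log2(a) for the whole array before choosing, the log of negative numbers is still computed (and then thrown away). The sketch below shows a compound condition and then, as an alternative that is not from the original, the ufunc `where=` argument, which computes log2 only where the mask is True. The seed and the band from -5 to 5 are illustrative assumptions:

```python
import numpy as np

# seeded generator for reproducibility (our addition)
rng = np.random.default_rng(42)
a = rng.normal(size=(20, 4, 4)) * 10

# compound condition: zero out values strictly between -5 and 5
banded = np.where((a > -5) & (a < 5), 0, a)

# ufunc where=: log2 is computed only where mask is True;
# other positions keep the value already in `out` (here a copy of a)
mask = a >= 5
b = np.log2(a, out=a.copy(), where=mask)

# the np.where version evaluates log2 everywhere, so silence its warnings
with np.errstate(invalid='ignore', divide='ignore'):
    b_where = np.where(mask, np.log2(a), a)

print(np.allclose(b, b_where))  # True
```

When using `where=`, always pass `out`: without it, the positions where the mask is False are left uninitialized.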
