Numpy in Python
- -
Variables and Literals¶
`=` means assign.
`a=5` means assign the numeric literal `5` to the variable `a`.
`a="Five"` means assign the string literal `"Five"` to the variable `a`.
a = 5
print("a =", a)
a = "Five"
print("a =", a)
a = 5
a = Five
Operators¶
x = 14
y = 4
# Add two operands
print('x + y =', x+y) # Output: x + y = 18
# Subtract right operand from the left
print('x - y =', x-y) # Output: x - y = 10
# Multiply two operands
print('x * y =', x*y) # Output: x * y = 56
# Divide left operand by the right one
print('x / y =', x/y) # Output: x / y = 3.5
# Floor division (quotient)
print('x // y =', x//y) # Output: x // y = 3
# Remainder of the division of left operand by the right
print('x % y =', x%y) # Output: x % y = 2
# Left operand raised to the power of right (x^y)
print('x ** y =', x**y) # Output: x ** y = 38416
x + y = 18
x - y = 10
x * y = 56
x / y = 3.5
x // y = 3
x % y = 2
x ** y = 38416
x = 5
# x += 5 ----> x = x + 5
x +=5
print(x) # Output: 10
# x /= 5 ----> x = x / 5
x /= 5
print(x) # Output: 2.0
10
2.0
Type Conversion¶
Implicit Type Conversion¶
num_int = 123 # integer type
num_flo = 1.23 # float type
num_new = num_int + num_flo
print("Value of num_new:",num_new)
print("datatype of num_new:",type(num_new))
Value of num_new: 124.23
datatype of num_new: <class 'float'>
num_int = 123 # int type
num_str = "456" # str type
print(num_int+num_str)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [5], in <module>
1 num_int = 123 # int type
2 num_str = "456" # str type
----> 4 print(num_int+num_str)
TypeError: unsupported operand type(s) for +: 'int' and 'str'
Explicit Type Conversion¶
num_int = 123 # int type
num_str = "456" # str type
# explicitly converted to int type
num_str = int(num_str)
print(num_int+num_str)
579
Numeric Types¶
# Output: <class 'int'>
print(type(5))
# Output: <class 'float'>
print(type(5.0))
c = 5 + 3j
# Output: <class 'complex'>
print(type(c))
<class 'int'>
<class 'float'>
<class 'complex'>
Boolean Types¶
print(type(True))
print(True and True) # True
print(False and True) # False
print(True or False) # True
<class 'bool'>
True
False
True
Data Structures¶
Lists¶
# empty list
my_list = []
my_list = list()
# list of integers
my_list = [1, 2, 3]
# list with mixed data types
my_list = [1, "Hello", 3.4]
language = ["French", "German", "English", "Polish"]
# Accessing first element
print(language[0])
# Accessing fourth element
print(language[3])
# Accessing last element (negative indices count from the end)
print(language[-1])
# Get sub list (slice): elements at indices 1 and 2; the stop index 3 is excluded
print(language[1:3])
French
Polish
Polish
['German', 'English']
`[m:n]` selects the elements at indices $i$ with $m\leq i< n$ (the stop index $n$ is excluded).
language[1:3] = ['Korean']
print(language)
['French', 'Korean', 'Polish']
Tuples¶
Basically, lists are mutable whereas tuples are immutable.
language = ("French", "German", "English", "Polish")
print(language)
print(language[1]) #Output: German
print(language[3]) #Output: Polish
print(language[-1]) # Output: Polish
('French', 'German', 'English', 'Polish')
German
Polish
Polish
# TypeError: 'tuple' object does not support item assignment
language[1:3] = ['Korean']
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [13], in <module>
1 # TypeError: 'tuple' object does not support item assignment
----> 2 language[1:3] = ['Korean']
TypeError: 'tuple' object does not support item assignment
Strings¶
# all of the following are equivalent
my_string = 'Hello'
print(my_string)
my_string = "Hello"
print(my_string)
my_string = '''Hello'''
print(my_string)
# triple quotes string can extend multiple lines
my_string = """Hello, welcome to
the world of Python"""
print(my_string)
Hello
Hello
Hello
Hello, welcome to
the world of Python
# Use a name other than `str`: rebinding `str` at module level would shadow the
# built-in `str` type for every later cell in the session.
text = 'Mathematics'
print('str = ', text)
print('str[0] = ', text[0])  # Output: M
print('str[-1] = ', text[-1])  # Output: s
# slicing 2nd to 5th character (indices 1..4; the stop index 5 is excluded)
print('str[1:5] = ', text[1:5])  # Output: athe
# slicing 6th to 3rd-last character (the stop index -2, the 2nd-last, is excluded)
print('str[5:-2] = ', text[5:-2])  # Output: mati
str = Mathematics
str[0] = M
str[-1] = s
str[1:5] = athe
str[5:-2] = mati
str1 = 'Hello '
str2 ='World!'
# Output: Hello World!
print(str1 + str2)
# Hello Hello Hello
print(str1 * 3)
Hello World!
Hello Hello Hello
Sets¶
# set of integers
my_set = {1, 2, 3}
print(my_set)
# set of mixed datatypes
my_set = {1.0, "Hello", (1, 2, 3)}
print(my_set)
{1, 2, 3}
{1.0, 'Hello', (1, 2, 3)}
# set of integers
my_set = {1, 2, 3}
my_set.add(4)
print(my_set) # Output: {1, 2, 3, 4}
my_set.add(2)
print(my_set) # Output: {1, 2, 3, 4}
my_set.update([3, 4, 5])
print(my_set) # Output: {1, 2, 3, 4, 5}
my_set.remove(4)
print(my_set) # Output: {1, 2, 3, 5}
{1, 2, 3, 4}
{1, 2, 3, 4}
{1, 2, 3, 4, 5}
{1, 2, 3, 5}
A = {1, 2, 3}
B = {2, 3, 4, 5}
# Equivalent to A.union(B)
# Also equivalent to B.union(A)
print(A | B) # Output: {1, 2, 3, 4, 5}
# Equivalent to A.intersection(B)
# Also equivalent to B.intersection(A)
print (A & B) # Output: {2, 3}
# Set Difference
print (A - B) # Output: {1}
# Set Symmetric Difference
print(A ^ B) # Output: {1, 4, 5}
{1, 2, 3, 4, 5}
{2, 3}
{1}
{1, 4, 5}
Dictionaries¶
# empty dictionary
my_dict = {}
# dictionary with integer keys
my_dict = {1: 'apple', 2: 'ball'}
# dictionary with mixed keys
my_dict = {'name': 'John', 1: [2, 4, 3]}
person = {'name':'Jack', 'age': 26, 'salary': 4534.2}
print(person['age']) # Output: 26
26
person = {'name':'Jack', 'age': 26}
# Changing age to 36
person['age'] = 36
print(person) # Output: {'name': 'Jack', 'age': 36}
# Adding salary key, value pair
person['salary'] = 4342.4
print(person) # Output: {'name': 'Jack', 'age': 36, 'salary': 4342.4}
# Deleting age
del person['age']
print(person) # Output: {'name': 'Jack', 'salary': 4342.4}
# Deleting entire dictionary
del person
{'name': 'Jack', 'age': 36}
{'name': 'Jack', 'age': 36, 'salary': 4342.4}
{'name': 'Jack', 'salary': 4342.4}
Python range()¶
print(range(1, 10)) # Output: range(1, 10)
range(1, 10)
numbers = range(1, 6)
print(list(numbers)) # Output: [1, 2, 3, 4, 5]
print(tuple(numbers)) # Output: (1, 2, 3, 4, 5)
print(set(numbers)) # Output: {1, 2, 3, 4, 5}
[1, 2, 3, 4, 5]
(1, 2, 3, 4, 5)
{1, 2, 3, 4, 5}
# Equivalent to: numbers = range(1, 6)
numbers1 = range(1, 6 , 1)
print(list(numbers1)) # Output: [1, 2, 3, 4, 5]
numbers2 = range(1, 6, 2)
print(list(numbers2)) # Output: [1, 3, 5]
numbers3 = range(5, 0, -1)
print(list(numbers3)) # Output: [5, 4, 3, 2, 1]
[1, 2, 3, 4, 5]
[1, 3, 5]
[5, 4, 3, 2, 1]
Control Flow¶
if...else Statement¶
num = -1
if num > 0:
print("Positive number")
elif num == 0:
print("Zero")
else:
print("Negative number")
# Output: Negative number
Negative number
if False:
print("I am inside the body of if.")
print("I am also inside the body of if.")
print("I am outside the body of if")
# Output: I am outside the body of if.
I am outside the body of if
- Indentation is important!
while Loop¶
n = 100
# initialize the running total and counter; the total is called `total`
# (not `sum`) so the built-in sum() is not shadowed for later cells
total = 0
i = 1
while i <= n:
    total = total + i
    i = i + 1  # update counter
print("The sum is", total)
# Output: The sum is 5050
The sum is 5050
for Loop¶
numbers = [6, 5, 3, 8, 4, 2]
# running total; named `total` so the built-in sum() stays available
total = 0
# iterate over the list
for val in numbers:
    total = total + val
print("The sum is", total)  # Output: The sum is 28
The sum is 28
break Statement¶
for val in "Mathematics":
if val == "e":
break
print(val)
else:
print("The end")
M
a
t
h
continue Statement¶
for val in "Mathematics":
if val == "e":
continue
print(val)
else:
print("The end")
M
a
t
h
m
a
t
i
c
s
The end
pass Statement¶
for val in "Mathematics":
pass
else:
print("The end")
The end
# IndentationError: expected an indented block
for val in "Mathematics":
else:
print("The end")
Input In [33]
else:
^
IndentationError: expected an indented block
Function¶
def print_lines():
    """Print two fixed demo lines, each on its own line."""
    for line in ("I am line1.", "I am line2."):
        print(line)
# function call
print_lines()
I am line1.
I am line2.
def add_numbers(a, b):
    """Return ``a + b``.

    The result is returned directly rather than stored in a local called
    ``sum``, which would shadow the built-in ``sum()`` inside the function.
    """
    return a + b
result = add_numbers(4, 5)
print(result)
# Output: 9
9
Decorator¶
- A decorator is a function that takes a function and returns a (modified) function.
def one_more(ftn):
    """Decorate ``ftn`` so each call returns ``ftn(...) + 1``.

    Args:
        ftn: the function to wrap; its result must support ``+ 1``.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``ftn`` and adds 1 to its result. ``functools.wraps`` copies
        ``ftn``'s ``__name__``/``__doc__`` onto the wrapper, so decorated
        functions keep their identity.
    """
    import functools

    @functools.wraps(ftn)
    def wrapper(*args, **kwargs):
        # *args/**kwargs: still accepts the original (a, b) calls, but no
        # longer restricts the decorator to two-argument functions.
        return ftn(*args, **kwargs) + 1

    return wrapper
one_more(add_numbers)(4, 5)
10
@one_more
def add_numbers_one_more(a, b):
    """Return ``a + b``; the ``@one_more`` decorator then adds 1 to it.

    The sum is returned directly instead of via a local named ``sum``,
    which would shadow the built-in ``sum()``.
    """
    return a + b
add_numbers_one_more(4,5)
10
def more(n):
    """Return a decorator that adds ``n`` to the wrapped function's result.

    ``more(n)(ftn)(*args, **kwargs)`` evaluates to ``ftn(*args, **kwargs) + n``.

    Args:
        n: the amount added to every result of the decorated function.

    Returns:
        A decorator. The wrappers it produces forward all positional and
        keyword arguments and use ``functools.wraps`` so the decorated
        function keeps its ``__name__`` and ``__doc__``.
    """
    import functools

    def outer(ftn):
        @functools.wraps(ftn)
        def inner(*args, **kwargs):
            # Forward every argument rather than hard-coding (a, b).
            return ftn(*args, **kwargs) + n

        return inner

    return outer
more(10)(add_numbers)(4,5)
19
@more(10)
def add_numbers_ten_more(a, b):
    """Return ``a + b``; the ``@more(10)`` decorator then adds 10 to it.

    The sum is returned directly instead of via a local named ``sum``,
    which would shadow the built-in ``sum()``.
    """
    return a + b
add_numbers_ten_more(4,5)
19
Lambda Function¶
square = lambda x: x ** 2
print(square(5))
# Output: 25
25
numbers = [6, 5, 3, 8, 4, 2]
numbers2 = map(lambda x: x ** 2, numbers)
print(list(numbers2))
[36, 25, 9, 64, 16, 4]
numbers_lt_5 = filter(lambda x: x<5, numbers)
print(list(numbers_lt_5))
[3, 4, 2]
Generators¶
# A simple generator function
def my_gen():
    """Yield 1, 2 and 3, printing a progress message just before each value.

    Calling my_gen() only creates the generator. Each next() runs the body
    up to the following yield. After the third value the body finishes, so
    any further next() raises StopIteration.
    """
    messages = (
        'This is printed first',
        'This is printed second',
        'This is printed at last',
    )
    for n, message in enumerate(messages, start=1):
        print(message)
        # Generator function contains yield statements
        yield n
# Calling a generator function returns a generator object; its body does not start running yet.
a = my_gen()
# We can step through the items one at a time using next().
print(next(a)) # 1
# Once the function yields, it is paused and control is transferred back to the caller.
# Local variables and their states are remembered between successive calls.
print(next(a)) # 2
print(next(a)) # 3
# Finally, when the function body finishes, any further call raises StopIteration.
# This last call is expected to fail: it demonstrates that exception.
print(next(a))
This is printed first
1
This is printed second
2
This is printed at last
3
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
Input In [43], in <module>
26 print(next(a)) # 3
28 # Finally, when the function terminates, StopIteration is raised automatically on further calls.
---> 29 print(next(a))
StopIteration:
# Using for loop
for item in my_gen():
print(item)
This is printed first
1
This is printed second
2
This is printed at last
3
Comprehension¶
# list comprehension
numbers = [6, 5, 3, 8, 4, 2]
print([x**2 for x in numbers])
[36, 25, 9, 64, 16, 4]
# generator expression
numbers = [6, 5, 3, 8, 4, 2]
print((x**2 for x in numbers))
<generator object <genexpr> at 0x10cc7cc80>
for i in (x**2 for x in numbers):
print(i)
36
25
9
64
16
4
Class¶
class MyClass:
    """Hold a single value ``a`` that ``add`` increases.

    ``add`` returns the instance itself, so calls can be chained.
    """

    def __init__(self, a):
        # Store the initial value on the instance.
        self.a = a

    def add(self, b):
        """Rebind ``self.a`` to ``self.a + b`` and return ``self``.

        A new object is bound (rather than using ``+=``), so a mutable value
        passed to ``__init__`` is never modified in place.
        """
        updated = self.a + b
        self.a = updated
        return self
obj1 = MyClass(10)
print(obj1.a) # 10
obj1.add(3)
print(obj1.a) # 13
obj2 = MyClass(15)
print(obj2.a) # 15
10
13
15
class MyClass2(MyClass):
    """``MyClass`` extended with a chainable ``mul`` method."""

    def mul(self, b):
        """Rebind ``self.a`` to ``self.a * b`` and return ``self`` for chaining."""
        product = self.a * b
        self.a = product
        return self
obj3 = MyClass2(10)
print(obj3.a) # 10
obj3.mul(3).add(5)
print(obj3.a) # 35
10
35
Annotation¶
class MyClass:
    """Hold a value ``a`` (annotated as float) that ``add`` increments.

    The annotations are documentation only; Python does not enforce them.
    """

    a: float  # class-level annotation; no default value is assigned

    def __init__(self, a: float):
        self.a = a

    def add(self, b: float):
        """Augment ``self.a`` by ``b`` with ``+=`` and return ``self`` for chaining."""
        self.a += b
        return self
obj = MyClass(3.0)
print(obj.a)
obj.add(2.0).add(3.0)
print(obj.a)
3.0
8.0
- Annotations are NOT static typing: Python does not enforce them at runtime, so passing `int` values below still works.
obj = MyClass(3)
print(obj.a)
obj.add(2).add(3)
print(obj.a)
3
8
Data Class¶
from dataclasses import dataclass
@dataclass
class InventoryItem:
    """Class for keeping track of an item in inventory."""
    name: str
    unit_price: float
    quantity_on_hand: int = 0

    # No __init__ is written here: @dataclass generates
    # __init__(self, name, unit_price, quantity_on_hand=0) from the annotated
    # fields above, assigning each argument to the attribute of the same name.

    def total_cost(self) -> float:
        """Return the value of the stock: unit price times quantity on hand."""
        price = self.unit_price
        return price * self.quantity_on_hand
item1 = InventoryItem('Apple', 300, 5)
print(item1.total_cost())
1500
Numpy Quickstart¶
import numpy as np
Shape of an array¶
ndarray.ndim
will tell you the number of axes, or dimensions, of the array.
ndarray.size
will tell you the total number of elements of the array. This is the product of the elements of the array’s shape.
ndarray.shape
will display a tuple of integers that indicate the number of elements stored along each dimension of the array. If, for example, you have a 2-D array with 2 rows and 3 columns, the shape of your array is (2, 3).
array_example = np.array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0 ,1 ,2, 3],
[4, 5, 6, 7]]])
array_example
array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]]])
array_example.ndim
3
array_example.size
24
array_example.shape
(3, 2, 4)
array_example.reshape((3,8))
array([[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7]])
array_example.reshape((6,4))
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[0, 1, 2, 3],
[4, 5, 6, 7],
[0, 1, 2, 3],
[4, 5, 6, 7]])
Indexing and Slicing¶
array_example = np.array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0 ,1 ,2, 3],
[4, 5, 6, 7]]])
array_example
array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]]])
array_example[1]
array([[0, 1, 2, 3],
[4, 5, 6, 7]])
array_example[1,0,0]
0
array_example[1,:,1:3]
array([[1, 2],
[5, 6]])
Basic Array Operations¶
- Element-wise operations
data = np.array([1, 2])
ones = np.ones(2, dtype=int)
data, ones
(array([1, 2]), array([1, 1]))
data + ones
array([2, 3])
data - ones
array([0, 1])
data * data
array([1, 4])
data / data
array([1., 1.])
data ** 2
array([1, 4])
a = np.array([1, 2, 3, 4])
a.sum()
10
b = np.array([[1, 1], [2, 2]])
b
array([[1, 1],
[2, 2]])
b.sum(axis=0)
array([3, 3])
b.sum(axis=1)
array([2, 4])
b.max()
2
a = np.array([[0.45053314, 0.17296777, 0.34376245, 0.5510652],
[0.54627315, 0.05093587, 0.40067661, 0.55645993],
[0.12697628, 0.82485143, 0.26590556, 0.56917101]])
a
array([[0.45053314, 0.17296777, 0.34376245, 0.5510652 ],
[0.54627315, 0.05093587, 0.40067661, 0.55645993],
[0.12697628, 0.82485143, 0.26590556, 0.56917101]])
a.min(), a.max(), a.sum(), a.mean()
(0.05093587, 0.82485143, 4.8595784, 0.4049648666666667)
a.shape
(3, 4)
a.min(axis=0)
array([0.12697628, 0.05093587, 0.26590556, 0.5510652 ])
Broadcasting¶
array_example = np.array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0 ,1 ,2, 3],
[4, 5, 6, 7]]])
array_example
array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 1, 2, 3],
[4, 5, 6, 7]]])
Multiply scalar.
array_example * 3
array([[[ 0, 3, 6, 9],
[12, 15, 18, 21]],
[[ 0, 3, 6, 9],
[12, 15, 18, 21]],
[[ 0, 3, 6, 9],
[12, 15, 18, 21]]])
Consider $\mathbb R^{3\times 2\times 4}$ as the vector space $((\mathbb R^4)^2)^3$.
An element of $\mathbb R^4$ then acts like a scalar on $(\mathbb R^4)^2$ and on $((\mathbb R^4)^2)^3$: it is broadcast along the leading axes.
array_example.shape
(3, 2, 4)
array_example * np.array([0,1,2,3])
array([[[ 0, 1, 4, 9],
[ 0, 5, 12, 21]],
[[ 0, 1, 4, 9],
[ 0, 5, 12, 21]],
[[ 0, 1, 4, 9],
[ 0, 5, 12, 21]]])
An element of $(\mathbb R^4)^2=\mathbb R^{2\times 4}$ is a scalar in $((\mathbb R^4)^2)^3=\mathbb R^{3\times 2\times 4}$.
array_example * np.array([[0,1,2,3], [4,5,6,7]])
array([[[ 0, 1, 4, 9],
[16, 25, 36, 49]],
[[ 0, 1, 4, 9],
[16, 25, 36, 49]],
[[ 0, 1, 4, 9],
[16, 25, 36, 49]]])
We can re-order the powers, so $\mathbb R^{3\times 2\times 4}$ can be considered as $((\mathbb R^{2})^3)^4$.
So an element of $\mathbb R^2=\mathbb R^{1\times 2\times 1}$ is a scalar in $\mathbb R^{3\times 2\times 4}$.
np.array([[[0],[1]]])
array([[[0],
[1]]])
np.array([[[0],[1]]]).shape
(1, 2, 1)
array_example * np.array([[[0],[1]]])
array([[[0, 0, 0, 0],
[4, 5, 6, 7]],
[[0, 0, 0, 0],
[4, 5, 6, 7]],
[[0, 0, 0, 0],
[4, 5, 6, 7]]])
np.array([0,1])[np.newaxis,:,np.newaxis]
array([[[0],
[1]]])
We can re-order the powers, so $\mathbb R^{3\times 2\times 4}$ can be considered as $((\mathbb R^3)^2)^4$.
So an element of $\mathbb R^3=\mathbb R^{3\times 1\times 1}$ is a scalar in $\mathbb R^{3\times 2\times 4}$.
array_example * np.array([0,1,0])[:,np.newaxis,np.newaxis]
array([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[0, 0, 0, 0],
[0, 0, 0, 0]]])
array_example + np.array([0,1,2])[:,np.newaxis,np.newaxis]
array([[[0, 1, 2, 3],
[4, 5, 6, 7]],
[[1, 2, 3, 4],
[5, 6, 7, 8]],
[[2, 3, 4, 5],
[6, 7, 8, 9]]])
Linear Algebra¶
Matrix Product¶
A = np.random.rand(3,4)
B = np.random.rand(4,5)
np.matmul(A,B)
array([[1.12530675, 0.67299837, 0.80348367, 0.87924637, 1.09648698],
[1.23792231, 0.68697281, 0.67662029, 0.75067197, 1.20216618],
[2.01218187, 1.27456069, 1.5051888 , 1.76045733, 1.89066128]])
A@B
array([[1.12530675, 0.67299837, 0.80348367, 0.87924637, 1.09648698],
[1.23792231, 0.68697281, 0.67662029, 0.75067197, 1.20216618],
[2.01218187, 1.27456069, 1.5051888 , 1.76045733, 1.89066128]])
For batched input
x = np.random.rand(128,3,4)
W = np.random.rand(4,5)
np.matmul(x,W).shape
(128, 3, 5)
Einstein Summation Convention¶
Given two tensors $\mathbb u, \mathbb v$ with dimensions $(k_1,\ldots,k_m)$ and $(l_1,\ldots,l_n)$, their outer product $\mathbb u\otimes\mathbb v$ is a tensor with dimensions $(k_1,\ldots,k_m,l_1,\ldots,l_n)$ and entries $$(\mathbb u\otimes\mathbb v)_{i_1,\ldots,i_m,j_1,\ldots,j_n} = u_{i_1,\ldots,i_m} v_{j_1,\ldots,j_n}.$$
- $\mathbb a\in\mathbb R^3, \mathbb b\in\mathbb R^3, (\mathbb a\otimes\mathbb b)\in\mathbb R^{3\times 3}$
- $(\mathbb a\otimes\mathbb b)_{i,j} = a_ib_j$
import numpy as np
a = np.array([1,2,3])
b = np.array([5,6,7])
a[:,np.newaxis] * b[np.newaxis,:]
array([[ 5, 6, 7],
[10, 12, 14],
[15, 18, 21]])
np.einsum('i,j->ij', a, b)
array([[ 5, 6, 7],
[10, 12, 14],
[15, 18, 21]])
If an index is omitted from the right-hand side of `->`, the corresponding axis is summed over.
np.einsum('i,j->i', a, b)
- $\sum_j(\mathbb a\otimes\mathbb b)_{i,j} = \sum_ja_ib_j$
(a[:,np.newaxis] * b[np.newaxis,:]).sum(axis=1)
array([18, 36, 54])
np.einsum('i,j->i', a, b)
array([18, 36, 54])
np.einsum('i,j->j', a, b)
- $\sum_i(\mathbb a\otimes\mathbb b)_{i,j} = \sum_ia_ib_j$
(a[:,np.newaxis] * b[np.newaxis,:]).sum(axis=0)
array([30, 36, 42])
np.einsum('i,j->j', a, b)
array([30, 36, 42])
If the same index appears more than once on the left-hand side of `->`, the diagonal along those axes is taken.
np.einsum('i,i->i', a, b)
- $(\mathbb a\otimes\mathbb b)_{i,i}=a_ib_i$
a*b
array([ 5, 12, 21])
np.einsum('i,i->i', a, b)
array([ 5, 12, 21])
np.einsum('i,i->', a, b)
- $\sum_i(\mathbb a\otimes\mathbb b)_{i,i}=\sum_ia_ib_i=\mathbb a\cdot\mathbb b$
np.dot(a,b)
38
(a*b).sum()
38
np.einsum('i,i->', a, b)
38
$N$-D case ($N\geq 3$)
- $\mathbb a\in\mathbb R^3, B\in\mathbb R^{2\times 3}, \mathbb a\otimes B\in\mathbb R^{3\times(2\times 3)}$
- $(\mathbb a\otimes B)_{i,j,k} = a_ib_{j,k}$
a = np.array([0,1,2])
B = np.array([[1,2,3],[4,5,6]])
a[:,np.newaxis,np.newaxis] * B[np.newaxis,:,:]
array([[[ 0, 0, 0],
[ 0, 0, 0]],
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 2, 4, 6],
[ 8, 10, 12]]])
np.einsum('i,jk->ijk', a, B)
array([[[ 0, 0, 0],
[ 0, 0, 0]],
[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 2, 4, 6],
[ 8, 10, 12]]])
np.einsum('i,jk->ijk', a, B).shape
(3, 2, 3)
- $(\mathbb a\otimes B)_{i,j,i}=a_ib_{j,i}$
np.stack([(a[:,np.newaxis,np.newaxis] * B[np.newaxis,:,:])[i,:,i] for i in range(3)])
array([[ 0, 0],
[ 2, 5],
[ 6, 12]])
np.einsum('i,ji->ij', a, B)
array([[ 0, 0],
[ 2, 5],
[ 6, 12]])
- $\sum_i(\mathbb a\otimes B)_{i,j,i}=\sum_ia_ib_{j,i}=(B\mathbb a)_j$, i.e. the matrix-vector product $B\mathbb a$
np.dot(B,a)
array([ 8, 17])
np.einsum('ij,j->i', B, a)
array([ 8, 17])
np.einsum('i,ji->j', a, B)
array([ 8, 17])
Batch matrix product
- $x\in\mathbb R^{128\times (3\times 4)}, W\in\mathbb R^{4\times 5}$
- $\sum_j(x\otimes W)_{b,i,j,j,k} = \sum_jx_{b,i,j}W_{j,k}$
x = np.random.rand(128,3,4)
W = np.random.rand(4,5)
np.matmul(x,W).shape
(128, 3, 5)
np.einsum('bij,jk->bik', x, W).shape
(128, 3, 5)
np.einsum('...j,jk->...k', x, W).shape
(128, 3, 5)
np.allclose(np.matmul(x,W), np.einsum('bij,jk->bik', x, W))
True
np.allclose(np.matmul(x,W), np.einsum('...j,jk->...k', x, W))
True
'Coding > Python' 카테고리의 다른 글
Crash Cource in Python (0) | 2024.11.03 |
---|---|
FastAPI를 이용한 웹캠 스트리밍 서버 (0) | 2024.10.29 |
CS231n Python Tutorial (0) | 2024.09.10 |
다른 폴더 파일 import (0) | 2024.04.09 |
상위 폴더 파일 import (0) | 2023.03.06 |
소중한 공감 감사합니다