I have a numpy array in the following format:
array([list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457]),
list([0.71641791, 0.71641791, 0.71641791, 0.69565217, 0.69565217]),
list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577]),
...,
list([0.14285714, 0.14285714, 0.14285714, 0.14285714, 0.13793103]),
list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791]),
list([0.72727273, 0.72727273, 0.72727273, 0.70588235, 0.70588235])],
dtype=object)
How can I use np.argwhere(y>0.5) on the above numpy array?
I get an error when using np.argwhere:
TypeError                                 Traceback (most recent call last)
Input In [132], in <cell line: 1>()
----> 1 z=np.argwhere(y>0.5)
TypeError: '>' not supported between instances of 'list' and 'float'
I changed your input and made some of the lists shorter than the rest. As a result, the inner lists are not converted to np.array; they stay plain Python lists, as the print statements on arr show.
Pad them with itertools.zip_longest and you should be good to go. The fillvalue is up to you; pick whatever suits your data best.
import itertools

import numpy as np

arr = np.array([list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457]),
list([0.71641791, 0.71641791, 0.71641791, 0.69565217, ]),
list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577]),
list([0.14285714, 0.14285714, 0.14285714, 0.14285714, ]),
list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791]),
list([0.72727273, 0.72727273, 0.72727273, 0.70588235])], dtype='object')
print(type(arr))
print(type(arr[0]))
print(arr)
<class 'numpy.ndarray'>
<class 'list'>
[list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457])
list([0.71641791, 0.71641791, 0.71641791, 0.69565217])
list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577])
list([0.14285714, 0.14285714, 0.14285714, 0.14285714])
list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791])
list([0.72727273, 0.72727273, 0.72727273, 0.70588235])]
converted_arr = np.array(list(itertools.zip_longest(*arr, fillvalue=-1))).T
print(type(converted_arr))
print(type(converted_arr[0]))
print(converted_arr)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
[[ 0.28552457 0.28552457 0.28552457 0.28552457 0.28552457]
[ 0.71641791 0.71641791 0.71641791 0.69565217 -1. ]
[ 0.95626478 0.95626478 0.95513577 0.95513577 0.95513577]
[ 0.14285714 0.14285714 0.14285714 0.14285714 -1. ]
[ 0.73846154 0.73846154 0.73846154 0.71641791 0.71641791]
[ 0.72727273 0.72727273 0.72727273 0.70588235 -1. ]]
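To see why the trailing .T is needed, here is a minimal sketch with made-up lists (rows, columns and padded are illustrative names, not taken from the question). zip_longest iterates over the lists in parallel, so it yields columns; transposing turns them back into padded rows.

```python
import itertools

import numpy as np

# Illustrative ragged input, not the data from the question.
rows = [[1, 2, 3], [4, 5], [6]]

# zip_longest walks all rows in parallel, so each tuple is a *column*;
# rows that run out early are padded with fillvalue.
columns = list(itertools.zip_longest(*rows, fillvalue=-1))
print(columns)  # [(1, 4, 6), (2, 5, -1), (3, -1, -1)]

# Transposing restores the original row order, now with equal lengths.
padded = np.array(columns).T
print(padded.tolist())  # [[1, 2, 3], [4, 5, -1], [6, -1, -1]]
```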
Then, as you already did in your question:
np.argwhere(converted_arr>0.5)
array([[1, 0],
[1, 1],
[1, 2],
[1, 3],
[2, 0],
[2, 1],
[2, 2],
[2, 3],
[2, 4],
[4, 0],
[4, 1],
[4, 2],
[4, 3],
[4, 4],
[5, 0],
[5, 1],
[5, 2],
[5, 3]], dtype=int64)
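On the choice of fillvalue: -1 is safe for a > 0.5 test, but it would match a < 0.5 test. One option, sketched below with made-up data, is to pad with np.nan, since any ordering comparison against NaN is False. The names ragged, padded, above and below are illustrative only.

```python
import itertools

import numpy as np

# Illustrative ragged input, not the data from the question.
ragged = [[0.7, 0.2, 0.9], [0.6, 0.1], [0.3]]

# Pad short rows with NaN instead of -1.
padded = np.array(list(itertools.zip_longest(*ragged, fillvalue=np.nan))).T

# Some NumPy versions emit a RuntimeWarning when comparing NaN; silence it here.
with np.errstate(invalid="ignore"):
    above = np.argwhere(padded > 0.5)
    below = np.argwhere(padded < 0.5)

print(above.tolist())  # [[0, 0], [0, 2], [1, 0]]
print(below.tolist())  # [[0, 1], [1, 1], [2, 0]]
```

Neither result contains a padded position, so the same array can be reused for comparisons in either direction.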
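Generate the numpy array with the following code.
It should give you something like this.
You can then call np.argwhere(y>0.5) to get the indices of all values greater than 0.5. This is the output I got.
You can read more about numpy.argwhere here.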
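If every inner list has the same length, as the rows shown in the question appear to, no padding is needed. A minimal sketch, assuming equal lengths and using made-up values (rows and y2d are illustrative names): convert the object array to a regular 2-D float array first, then compare.

```python
import numpy as np

# Illustrative stand-in for y: an object array whose elements are equal-length lists.
rows = [[0.28, 0.28], [0.71, 0.69], [0.95, 0.45]]
y = np.empty(len(rows), dtype=object)
for i, row in enumerate(rows):
    y[i] = row  # each element stays a plain Python list

# y > 0.5 would raise TypeError here, because it compares list objects to a float.
# Rebuilding from nested lists gives a regular 2-D float array instead.
y2d = np.array(y.tolist(), dtype=float)
print(y2d.shape)  # (3, 2)

print(np.argwhere(y2d > 0.5).tolist())  # [[1, 0], [1, 1], [2, 0]]
```

This only works when the lists really are the same length; with ragged lists the conversion to a float array fails, and padding is the way to go.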