当前位置:优学网  >  在线题库

使用np查找阈值。argwhere

发表时间:2022-07-11 01:15:47 阅读:84

我有以下格式的numpy数组:

array([list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457]),
       list([0.71641791, 0.71641791, 0.71641791, 0.69565217, 0.69565217]),
       list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577]),
       ...,
       list([0.14285714, 0.14285714, 0.14285714, 0.14285714, 0.13793103]),
       list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791]),
       list([0.72727273, 0.72727273, 0.72727273, 0.70588235, 0.70588235])],
      dtype=object)

我如何使用

 np.argwhere(y>0.5)

对于上述numpy阵列.

我在使用np时出错.argwhere

TypeErrorTraceback (most recent call last) Input In [132], in <cell line: 1>() ----> 1 z=np.argwhere(y>0.5)  TypeError: '>' not supported between instances of 'list' and 'float'
🎖️ 优质答案
  • 使用以下代码生成numpy数组

    y = np.random.random(50).reshape(10, 5)
    

    应该给你这样的东西

    array([[0.0992988 , 0.62179431, 0.26247934, 0.84402507, 0.05931778],
           [0.41603546, 0.6375162 , 0.66391551, 0.10959321, 0.71086281],
           [0.94060165, 0.84174581, 0.68389615, 0.20276522, 0.88082071],
           [0.04820747, 0.66052068, 0.81348755, 0.67832623, 0.96918592],
           [0.02804541, 0.47816843, 0.30916056, 0.38798535, 0.52484326],
           [0.06011845, 0.51549552, 0.76312676, 0.44300283, 0.52805978],
           [0.33544995, 0.8518729 , 0.59308601, 0.04525118, 0.5162366 ],
           [0.04688691, 0.34093081, 0.07197314, 0.868233  , 0.9434406 ],
           [0.63620516, 0.63382467, 0.2054274 , 0.01997156, 0.90429261],
           [0.68956099, 0.37559233, 0.96284643, 0.32257647, 0.31922311]])
    

    你可以打电话给np.argwhere(y&gt0.5)`获取大于0.5的所有值的索引.这是我得到的输出.

    array([[0, 1],
           [0, 3],
           [1, 1],
           [1, 2],
           [1, 4],
           [2, 0],
           [2, 1],
           [2, 2],
           [2, 4],
           [3, 1],
           [3, 2],
           [3, 3],
           [3, 4],
           [4, 4],
           [5, 1],
           [5, 2],
           [5, 4],
           [6, 1],
           [6, 2],
           [6, 4],
           [7, 3],
           [7, 4],
           [8, 0],
           [8, 1],
           [8, 4],
           [9, 0],
           [9, 2]])
    

    你可以在numpy上阅读更多内容.argwhere`此处.

  • I changed your input and made some lists shorter than the rest. Doing that, the lists don't get converted to a np.array as you see in my print statements on arr. Use the code with itertools.zip_longest, then you should be good to go. The fill_value is your choice what fits you the best.

    arr = np.array([list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457]),
           list([0.71641791, 0.71641791, 0.71641791, 0.69565217, ]),
           list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577]),
           list([0.14285714, 0.14285714, 0.14285714, 0.14285714, ]),
           list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791]),
           list([0.72727273, 0.72727273, 0.72727273, 0.70588235])], dtype='object')
    print(type(arr))
    print(type(arr[0]))
    print(arr)
    
    <class 'numpy.ndarray'>
    <class 'list'>
    [list([0.28552457, 0.28552457, 0.28552457, 0.28552457, 0.28552457])
     list([0.71641791, 0.71641791, 0.71641791, 0.69565217])
     list([0.95626478, 0.95626478, 0.95513577, 0.95513577, 0.95513577])
     list([0.14285714, 0.14285714, 0.14285714, 0.14285714])
     list([0.73846154, 0.73846154, 0.73846154, 0.71641791, 0.71641791])
     list([0.72727273, 0.72727273, 0.72727273, 0.70588235])]
    
    converted_arr = np.array(list(itertools.zip_longest(*arr, fillvalue=-1))).T
    print(type(converted_arr))
    print(type(converted_arr[0]))
    print(converted_arr)
    
    <class 'numpy.ndarray'>
    <class 'numpy.ndarray'>
    [[ 0.28552457  0.28552457  0.28552457  0.28552457  0.28552457]
     [ 0.71641791  0.71641791  0.71641791  0.69565217 -1.        ]
     [ 0.95626478  0.95626478  0.95513577  0.95513577  0.95513577]
     [ 0.14285714  0.14285714  0.14285714  0.14285714 -1.        ]
     [ 0.73846154  0.73846154  0.73846154  0.71641791  0.71641791]
     [ 0.72727273  0.72727273  0.72727273  0.70588235 -1.        ]]

    and then as you already provided in your question:

    np.argwhere(converted_arr>0.5)
    
    array([[1, 0],
           [1, 1],
           [1, 2],
           [1, 3],
           [1, 4],
           [2, 0],
           [2, 1],
           [2, 2],
           [2, 3],
           [2, 4],
           [4, 0],
           [4, 1],
           [4, 2],
           [4, 3],
           [4, 4],
           [5, 0],
           [5, 1],
           [5, 2],
           [5, 3],
           [5, 4]], dtype=int64)
  • 相关问题