"""
Template for each `dtype` helper function for hashtable

WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
"""

{{py:

# Each tuple parametrizes one expansion of the per-dtype template loop below.
# Fields, in order:
#   name      : prefix of the key-vector class (<name>Vector)
#   dtype     : suffix of the generated helper names (value_count_<dtype>,
#               duplicated_<dtype>, ismember_<dtype>); for non-object dtypes
#               it also forms the <dtype>_t memoryview element type
#   ttype     : khash table flavour (kh_<ttype>_t, kh_init_<ttype>, ...)
#   c_type    : C-level type of the local variable that holds one element
#   to_c_type : conversion prefix applied to an element before it is used as
#               a hash key; empty string where the entries below need none
# name, dtype, ttype, c_type, to_c_type
dtypes = [('Complex128', 'complex128', 'complex128',
                         'khcomplex128_t', 'to_khcomplex128_t'),
          ('Complex64', 'complex64', 'complex64',
                        'khcomplex64_t', 'to_khcomplex64_t'),
          ('Float64', 'float64', 'float64', 'float64_t', ''),
          ('Float32', 'float32', 'float32', 'float32_t', ''),
          ('UInt64', 'uint64', 'uint64', 'uint64_t', ''),
          ('UInt32', 'uint32', 'uint32', 'uint32_t', ''),
          ('UInt16', 'uint16', 'uint16', 'uint16_t', ''),
          ('UInt8', 'uint8', 'uint8', 'uint8_t', ''),
          ('Object', 'object', 'pymap', 'object', '<PyObject*>'),
          ('Int64', 'int64', 'int64', 'int64_t', ''),
          ('Int32', 'int32', 'int32', 'int32_t', ''),
          ('Int16', 'int16', 'int16', 'int16_t', ''),
          ('Int8', 'int8', 'int8', 'int8_t', '')]

}}

{{for name, dtype, ttype, c_type, to_c_type in dtypes}}


@cython.wraparound(False)
@cython.boundscheck(False)
{{if dtype == 'object'}}
cdef value_count_{{dtype}}(ndarray[{{dtype}}] values, bint dropna, const uint8_t[:] mask=None):
{{else}}
cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8_t[:] mask=None):
{{endif}}
    """
    Count how often each distinct value occurs in ``values``.

    Parameters
    ----------
    values : {{dtype}} ndarray
    dropna : bool
        If True, entries considered missing are left out of the counts.
    mask : uint8 memoryview, optional
        When given (and ``dropna`` is True), a truthy ``mask[i]`` marks
        ``values[i]`` as missing instead of inspecting the value itself.

    Returns
    -------
    tuple
        ``(keys, counts)`` where ``keys`` holds the distinct values in the
        order they were first seen and ``counts`` is an int64 ndarray with
        the matching number of occurrences.

    Raises
    ------
    NotImplementedError
        If ``mask`` is given together with ``dropna=False``, or (object
        dtype only) whenever ``mask`` is given.
    """
    cdef:
        Py_ssize_t i = 0
        Py_ssize_t n = len(values)
        kh_{{ttype}}_t *table

        # Don't use Py_ssize_t, since table.n_buckets is unsigned
        khiter_t k

        {{c_type}} val

        int ret = 0
        bint uses_mask = mask is not None
        bint isna_entry = False

    if uses_mask and not dropna:
        raise NotImplementedError("uses_mask not implemented with dropna=False")

    {{if dtype == 'object'}}
    # Reject unsupported input *before* the hash table is allocated so that
    # this error path cannot leak the table.
    if uses_mask:
        raise NotImplementedError("uses_mask not implemented with object dtype")
    {{endif}}

    # we track the order in which keys are first seen (GH39009),
    # khash-map isn't insertion-ordered, thus:
    #    table maps keys to counts
    #    result_keys remembers the original order of keys

    result_keys = {{name}}Vector()
    table = kh_init_{{ttype}}()

    {{if dtype == 'object'}}
    kh_resize_{{ttype}}(table, n // 10)

    for i in range(n):
        val = values[i]
        if not dropna or not checknull(val):
            k = kh_get_{{ttype}}(table, {{to_c_type}}val)
            if k != table.n_buckets:
                # key already present: bump its counter
                table.vals[k] += 1
            else:
                # first sighting: insert with count 1 and remember the order
                k = kh_put_{{ttype}}(table, {{to_c_type}}val, &ret)
                table.vals[k] = 1
                result_keys.append(val)
    {{else}}
    kh_resize_{{ttype}}(table, n)

    for i in range(n):
        val = {{to_c_type}}(values[i])

        if dropna:
            # missingness comes from the mask when one is supplied,
            # otherwise from the value itself
            if uses_mask:
                isna_entry = mask[i]
            else:
                isna_entry = is_nan_{{c_type}}(val)

        if not dropna or not isna_entry:
            k = kh_get_{{ttype}}(table, val)
            if k != table.n_buckets:
                # key already present: bump its counter
                table.vals[k] += 1
            else:
                # first sighting: insert with count 1 and remember the order
                k = kh_put_{{ttype}}(table, val, &ret)
                table.vals[k] = 1
                result_keys.append(val)
    {{endif}}

    # collect counts in the order corresponding to result_keys:
    cdef:
        int64_t[::1] result_counts = np.empty(table.size, dtype=np.int64)

    for i in range(table.size):
        {{if dtype == 'object'}}
        k = kh_get_{{ttype}}(table, result_keys.data[i])
        {{else}}
        k = kh_get_{{ttype}}(table, result_keys.data.data[i])
        {{endif}}
        result_counts[i] = table.vals[k]

    kh_destroy_{{ttype}}(table)

    return result_keys.to_array(), result_counts.base


@cython.wraparound(False)
@cython.boundscheck(False)
{{if dtype == 'object'}}
cdef duplicated_{{dtype}}(ndarray[{{dtype}}] values, object keep='first', const uint8_t[:] mask=None):
{{else}}
cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first', const uint8_t[:] mask=None):
{{endif}}
    """
    Flag the entries of ``values`` that are duplicates.

    Parameters
    ----------
    values : {{dtype}} ndarray
    keep : {'first', 'last', False}, default 'first'
        - 'first' : scan front to back; every occurrence after the first
          one is flagged.
        - 'last' : scan back to front; every occurrence before the last
          one is flagged.
        - False : every occurrence of a repeated value is flagged.
    mask : uint8 memoryview, optional
        When given, entries with a truthy ``mask[i]`` bypass the hash table
        and are treated as a single group of missing values.

    Returns
    -------
    ndarray of bool, same length as ``values``

    Raises
    ------
    ValueError
        If ``keep`` is not one of the accepted values.
    """
    cdef:
        int ret = 0
        {{if dtype != 'object'}}
        {{c_type}} value
        {{else}}
        PyObject* value
        {{endif}}
        Py_ssize_t i, n = len(values), first_na = -1
        khiter_t k
        kh_{{ttype}}_t *table
        ndarray[uint8_t, ndim=1, cast=True] out
        bint seen_na = False, uses_mask = mask is not None
        bint seen_multiple_na = False

    # Validate the argument before anything is allocated: raising after
    # kh_init would leak the hash table.
    if keep not in ('last', 'first', False):
        raise ValueError('keep must be either "first", "last" or False')

    # Allocate the output first so that a failed allocation cannot leak
    # the table either.
    out = np.empty(n, dtype='bool')

    table = kh_init_{{ttype}}()
    kh_resize_{{ttype}}(table, min(kh_needed_n_buckets(n), SIZE_HINT_LIMIT))

    {{for cond, keep in [('if', '"last"'), ('elif', '"first"')]}}
    {{cond}} keep == {{keep}}:
        {{if dtype == 'object'}}
        if True:
        {{else}}
        with nogil:
        {{endif}}
            {{if keep == '"last"'}}
            for i in range(n - 1, -1, -1):
            {{else}}
            for i in range(n):
            {{endif}}
                if uses_mask and mask[i]:
                    # masked entries: only the first one met in scan
                    # order is kept
                    if seen_na:
                        out[i] = True
                    else:
                        out[i] = False
                        seen_na = True
                else:
                    value = {{to_c_type}}(values[i])
                    kh_put_{{ttype}}(table, value, &ret)
                    # ret == 0 is used as "key was already in the table"
                    out[i] = ret == 0
    {{endfor}}

    else:
        {{if dtype == 'object'}}
        if True:
        {{else}}
        with nogil:
        {{endif}}
            for i in range(n):
                if uses_mask and mask[i]:
                    if not seen_na:
                        # remember where the first masked entry sits so it
                        # can be flagged retroactively
                        first_na = i
                        seen_na = True
                        out[i] = 0
                    elif not seen_multiple_na:
                        out[i] = 1
                        out[first_na] = 1
                        seen_multiple_na = True
                    else:
                        out[i] = 1

                else:
                    value = {{to_c_type}}(values[i])
                    k = kh_get_{{ttype}}(table, value)
                    if k != table.n_buckets:
                        # table.vals[k] holds the index of the first
                        # occurrence; flag it together with this one
                        out[table.vals[k]] = 1
                        out[i] = 1
                    else:
                        k = kh_put_{{ttype}}(table, value, &ret)
                        table.vals[k] = i
                        out[i] = 0

    kh_destroy_{{ttype}}(table)
    return out


# ----------------------------------------------------------------------
# Membership
# ----------------------------------------------------------------------


@cython.wraparound(False)
@cython.boundscheck(False)
{{if dtype == 'object'}}
cdef ismember_{{dtype}}(ndarray[{{c_type}}] arr, ndarray[{{c_type}}] values):
{{else}}
cdef ismember_{{dtype}}(const {{dtype}}_t[:] arr, const {{dtype}}_t[:] values):
{{endif}}
    """
    Element-wise membership test: is ``arr[i]`` contained in ``values``?

    Parameters
    ----------
    arr : {{dtype}} ndarray
        Elements to look up.
    values : {{dtype}} ndarray
        Elements that make up the lookup set.

    Returns
    -------
    boolean ndarray with the same length as ``arr``
    """
    cdef:
        int put_status = 0
        khiter_t slot
        Py_ssize_t pos, size
        ndarray[uint8_t] found

        {{if dtype == "object"}}
        PyObject* key
        {{else}}
        {{c_type}} key
        {{endif}}

        kh_{{ttype}}_t *table = kh_init_{{ttype}}()

    # Phase 1: load every element of `values` into the hash table.
    size = len(values)
    kh_resize_{{ttype}}(table, size)

    {{if dtype == 'object'}}
    if True:
    {{else}}
    with nogil:
    {{endif}}
        for pos in range(size):
            key = {{to_c_type}}(values[pos])
            kh_put_{{ttype}}(table, key, &put_status)

    # Phase 2: probe the table once per element of `arr`.
    size = len(arr)
    found = np.empty(size, dtype=np.uint8)

    {{if dtype == 'object'}}
    if True:
    {{else}}
    with nogil:
    {{endif}}
        for pos in range(size):
            key = {{to_c_type}}(arr[pos])
            slot = kh_get_{{ttype}}(table, key)
            # a lookup that does not land on n_buckets is a hit
            found[pos] = (slot != table.n_buckets)

    kh_destroy_{{ttype}}(table)
    return found.view(np.bool_)

# ----------------------------------------------------------------------
# Mode Computations
# ----------------------------------------------------------------------

{{endfor}}


# Fused element type accepted by the dispatching entry points in this file
# (value_count, duplicated, ismember, mode): everything covered by
# numeric_object_t plus the two complex types.
ctypedef fused htfunc_t:
    numeric_object_t
    complex128_t
    complex64_t


cpdef value_count(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=None):
    """
    Forward to the ``value_count_<dtype>`` helper that matches the element
    type of ``values``.

    Parameters
    ----------
    values : ndarray
    dropna : bool
        Passed through to the helper unchanged.
    mask : uint8 memoryview, optional
        Passed through to the helper unchanged.

    Returns
    -------
    Whatever the selected helper returns.

    Raises
    ------
    TypeError
        If no branch matches the element type.
    """
    # Each branch tests the fused type itself, so a given specialization of
    # this function takes exactly one of them.
    if htfunc_t is object:
        return value_count_object(values, dropna, mask=mask)

    elif htfunc_t is int8_t:
        return value_count_int8(values, dropna, mask=mask)
    elif htfunc_t is int16_t:
        return value_count_int16(values, dropna, mask=mask)
    elif htfunc_t is int32_t:
        return value_count_int32(values, dropna, mask=mask)
    elif htfunc_t is int64_t:
        return value_count_int64(values, dropna, mask=mask)

    elif htfunc_t is uint8_t:
        return value_count_uint8(values, dropna, mask=mask)
    elif htfunc_t is uint16_t:
        return value_count_uint16(values, dropna, mask=mask)
    elif htfunc_t is uint32_t:
        return value_count_uint32(values, dropna, mask=mask)
    elif htfunc_t is uint64_t:
        return value_count_uint64(values, dropna, mask=mask)

    elif htfunc_t is float64_t:
        return value_count_float64(values, dropna, mask=mask)
    elif htfunc_t is float32_t:
        return value_count_float32(values, dropna, mask=mask)

    elif htfunc_t is complex128_t:
        return value_count_complex128(values, dropna, mask=mask)
    elif htfunc_t is complex64_t:
        return value_count_complex64(values, dropna, mask=mask)

    else:
        raise TypeError(values.dtype)


cpdef duplicated(ndarray[htfunc_t] values, object keep="first", const uint8_t[:] mask=None):
    """
    Forward to the ``duplicated_<dtype>`` helper that matches the element
    type of ``values``.

    Parameters
    ----------
    values : ndarray
    keep : object, default "first"
        Passed through to the helper unchanged.
    mask : uint8 memoryview, optional
        Passed through to the helper unchanged.

    Returns
    -------
    Whatever the selected helper returns.

    Raises
    ------
    TypeError
        If no branch matches the element type.
    """
    # Each branch tests the fused type itself, so a given specialization of
    # this function takes exactly one of them.
    if htfunc_t is object:
        return duplicated_object(values, keep, mask=mask)

    elif htfunc_t is int8_t:
        return duplicated_int8(values, keep, mask=mask)
    elif htfunc_t is int16_t:
        return duplicated_int16(values, keep, mask=mask)
    elif htfunc_t is int32_t:
        return duplicated_int32(values, keep, mask=mask)
    elif htfunc_t is int64_t:
        return duplicated_int64(values, keep, mask=mask)

    elif htfunc_t is uint8_t:
        return duplicated_uint8(values, keep, mask=mask)
    elif htfunc_t is uint16_t:
        return duplicated_uint16(values, keep, mask=mask)
    elif htfunc_t is uint32_t:
        return duplicated_uint32(values, keep, mask=mask)
    elif htfunc_t is uint64_t:
        return duplicated_uint64(values, keep, mask=mask)

    elif htfunc_t is float64_t:
        return duplicated_float64(values, keep, mask=mask)
    elif htfunc_t is float32_t:
        return duplicated_float32(values, keep, mask=mask)

    elif htfunc_t is complex128_t:
        return duplicated_complex128(values, keep, mask=mask)
    elif htfunc_t is complex64_t:
        return duplicated_complex64(values, keep, mask=mask)

    else:
        raise TypeError(values.dtype)


cpdef ismember(ndarray[htfunc_t] arr, ndarray[htfunc_t] values):
    """
    Forward to the ``ismember_<dtype>`` helper that matches the element type
    shared by ``arr`` and ``values``.

    Parameters
    ----------
    arr : ndarray
        Passed to the helper as its first argument.
    values : ndarray
        Passed to the helper as its second argument.

    Returns
    -------
    Whatever the selected helper returns.

    Raises
    ------
    TypeError
        If no branch matches the element type.
    """
    # Each branch tests the fused type itself, so a given specialization of
    # this function takes exactly one of them.
    if htfunc_t is object:
        return ismember_object(arr, values)

    elif htfunc_t is int8_t:
        return ismember_int8(arr, values)
    elif htfunc_t is int16_t:
        return ismember_int16(arr, values)
    elif htfunc_t is int32_t:
        return ismember_int32(arr, values)
    elif htfunc_t is int64_t:
        return ismember_int64(arr, values)

    elif htfunc_t is uint8_t:
        return ismember_uint8(arr, values)
    elif htfunc_t is uint16_t:
        return ismember_uint16(arr, values)
    elif htfunc_t is uint32_t:
        return ismember_uint32(arr, values)
    elif htfunc_t is uint64_t:
        return ismember_uint64(arr, values)

    elif htfunc_t is float64_t:
        return ismember_float64(arr, values)
    elif htfunc_t is float32_t:
        return ismember_float32(arr, values)

    elif htfunc_t is complex128_t:
        return ismember_complex128(arr, values)
    elif htfunc_t is complex64_t:
        return ismember_complex64(arr, values)

    else:
        raise TypeError(values.dtype)


@cython.wraparound(False)
@cython.boundscheck(False)
def mode(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=None):
    """
    Return the value(s) that occur most often in ``values``.

    Counts are obtained from ``value_count``; every key whose count equals
    the maximum is returned, in the order ``value_count`` reports the keys.

    Parameters
    ----------
    values : ndarray
    dropna : bool
        Forwarded to ``value_count``.
    mask : uint8 memoryview, optional
        Forwarded to ``value_count``.

    Returns
    -------
    ndarray with the same dtype as ``values``
    """
    # TODO(cython3): use const htfunc_t[:]

    cdef:
        ndarray[htfunc_t] keys
        ndarray[htfunc_t] modes

        int64_t[::1] counts
        int64_t cnt, best = -1
        Py_ssize_t n_unique, idx, last = 0

    keys, counts = value_count(values, dropna, mask=mask)
    n_unique = len(keys)

    # worst case: every key is a mode
    modes = np.empty(n_unique, dtype=values.dtype)

    # `last` is the index of the most recently written slot of `modes`;
    # a strictly larger count restarts the collection at slot 0.
    if htfunc_t is not object:
        with nogil:
            for idx in range(n_unique):
                cnt = counts[idx]
                if cnt < best:
                    continue

                if cnt > best:
                    best = cnt
                    last = 0
                else:
                    last += 1

                modes[last] = keys[idx]
    else:
        # same loop, but object elements require the GIL
        for idx in range(n_unique):
            cnt = counts[idx]
            if cnt < best:
                continue

            if cnt > best:
                best = cnt
                last = 0
            else:
                last += 1

            modes[last] = keys[idx]

    return modes[:last + 1]


{{py:

# Rebinds `dtypes` for the template loop that follows: only the two
# signed-integer label types listed here get a generated function.
# name, dtype, ttype, c_type
dtypes = [('Int64', 'int64', 'int64', 'int64_t'),
          ('Int32', 'int32', 'int32', 'int32_t'), ]

}}

{{for name, dtype, ttype, c_type in dtypes}}


@cython.wraparound(False)
@cython.boundscheck(False)
def _unique_label_indices_{{dtype}}(const {{c_type}}[:] labels) -> ndarray:
    """
    Positions at which each distinct label first appears, ordered by label
    value, with the position belonging to label -1 left out. Equivalent to:
        np.unique(labels, return_index=True)[1]
    """
    cdef:
        int put_status = 0
        Py_ssize_t pos, n_labels = len(labels)
        kh_{{ttype}}_t *table = kh_init_{{ttype}}()
        {{name}}Vector first_seen = {{name}}Vector()
        ndarray[{{c_type}}, ndim=1] order
        {{name}}VectorData *buf = first_seen.data

    kh_resize_{{ttype}}(table, min(kh_needed_n_buckets(n_labels), SIZE_HINT_LIMIT))

    with nogil:
        for pos in range(n_labels):
            kh_put_{{ttype}}(table, labels[pos], &put_status)
            # a non-zero status is treated as "label not seen before"
            if put_status != 0:
                if needs_resize(buf):
                    # growing the vector goes through a Python-level method
                    with gil:
                        first_seen.resize()
                append_data_{{ttype}}(buf, pos)

    kh_destroy_{{ttype}}(table)

    # reorder the recorded positions by the label found at each of them
    order = first_seen.to_array()
    order = order[np.asarray(labels)[order].argsort()]

    # after sorting, a -1 label (if any) sits in front; drop it
    if order.size != 0 and labels[order[0]] == -1:
        return order[1:]
    return order

{{endfor}}