← 泛型概览 | 泛型类型 →

泛型函数 - 类型参数函数

泛型函数是带有类型参数的函数,可以适用于多种类型。本节深入讲解泛型函数的定义、约束、推断以及实际应用场景。

泛型函数核心模式

1. 转换函数

📝 类型转换与映射

package main

import (
    "fmt"
    "strconv"
)

// Map: 将一种类型转换为另一种
func Map[T, U any](slice []T, fn func(T) U) []U {
    result := make([]U, len(slice))
    for i, v := range slice {
        result[i] = fn(v)
    }
    return result
}

// Filter: 过滤满足条件的元素
func Filter[T any](slice []T, predicate func(T) bool) []T {
    result := []T{}
    for _, v := range slice {
        if predicate(v) {
            result = append(result, v)
        }
    }
    return result
}

// Reduce: 归约操作
func Reduce[T any, R any](slice []T, initial R, fn func(R, T) R) R {
    result := initial
    for _, v := range slice {
        result = fn(result, v)
    }
    return result
}

func main() {
    nums := []int{1, 2, 3, 4, 5}
    
    // Map: int → string
    strs := Map(nums, func(n int) string {
        return strconv.Itoa(n)
    })
    fmt.Println(strs) // [1 2 3 4 5]
    
    // Filter: 只保留偶数
    evens := Filter(nums, func(n int) bool {
        return n%2 == 0
    })
    fmt.Println(evens) // [2 4]
    
    // Reduce: 求和
    sum := Reduce(nums, 0, func(acc, n int) int {
        return acc + n
    })
    fmt.Println(sum) // 15
}

2. 数值计算函数

📝 泛型数值运算

package main

import (
    "fmt"
    "golang.org/x/exp/constraints"
)

// Sum: 数值类型求和
func Sum[T constraints.Numeric](nums ...T) T {
    var sum T
    for _, n := range nums {
        sum += n
    }
    return sum
}

// Average: 计算平均值
func Average[T constraints.Float | constraints.Integer](nums []T) float64 {
    if len(nums) == 0 {
        return 0
    }
    sum := Sum(nums...)
    return float64(sum) / float64(len(nums))
}

// Min/Max: 查找最值
func Min[T constraints.Ordered](nums ...T) T {
    if len(nums) == 0 {
        var zero T
        return zero
    }
    min := nums[0]
    for _, n := range nums[1:] {
        if n < min {
            min = n
        }
    }
    return min
}

func Max[T constraints.Ordered](nums ...T) T {
    if len(nums) == 0 {
        var zero T
        return zero
    }
    max := nums[0]
    for _, n := range nums[1:] {
        if n > max {
            max = n
        }
    }
    return max
}

func main() {
    // 整数运算
    fmt.Println(Sum(1, 2, 3, 4, 5))        // 15
    fmt.Println(Average([]int{1, 2, 3, 4, 5})) // 3.0
    fmt.Println(Min(10, 5, 8, 3, 9))         // 3
    fmt.Println(Max(10, 5, 8, 3, 9))         // 10
    
    // 浮点数运算
    fmt.Println(Sum(1.5, 2.5, 3.0))         // 7.0
    fmt.Println(Min(3.14, 2.71, 1.41))      // 1.41
    
    // 字符串比较
    fmt.Println(Min("banana", "apple", "cherry")) // "apple"
}

3. 集合操作函数

📝 泛型集合操作

package main

import "fmt"

// 泛型集合
type Set[T comparable] map[T]struct{}

func NewSet[T comparable]() Set[T] {
    return make(Set[T])
}

func (s Set[T]) Add(v T) {
    s[v] = struct{}{}
}

func (s Set[T]) Has(v T) bool {
    _, ok := s[v]
    return ok
}

func (s Set[T]) Delete(v T) {
    delete(s, v)
}

func (s Set[T]) Size() int {
    return len(s)
}

// 集合运算
func Union[T comparable](sets ...Set[T]) Set[T] {
    result := NewSet[T]()
    for _, s := range sets {
        for v := range s {
            result.Add(v)
        }
    }
    return result
}

func Intersect[T comparable](sets ...Set[T]) Set[T] {
    if len(sets) == 0 {
        return NewSet[T]()
    }
    
    result := NewSet[T]()
    for v := range sets[0] {
        allHave := true
        for _, s := range sets[1:] {
            if !s.Has(v) {
                allHave = false
                break
            }
        }
        if allHave {
            result.Add(v)
        }
    }
    return result
}

func main() {
    s1 := NewSet[int]()
    s1.Add(1)
    s1.Add(2)
    s1.Add(3)
    
    s2 := NewSet[int]()
    s2.Add(2)
    s2.Add(3)
    s2.Add(4)
    
    fmt.Println(Union(s1, s2).Size())         // 4
    fmt.Println(Intersect(s1, s2).Size())     // 2
}

高级技巧

方法集上的泛型

📝 泛型方法

package main

import "fmt"

// 泛型结构体
type Container[T any] struct {
    items []T
}

func NewContainer[T any]() *Container[T] {
    return &Container[T]{items: make([]T, 0)}
}

// 泛型方法
func (c *Container[T]) Add(item T) {
    c.items = append(c.items, item)
}

func (c *Container[T]) Get(index int) (T, bool) {
    var zero T
    if index < 0 || index >= len(c.items) {
        return zero, false
    }
    return c.items[index], true
}

func (c *Container[T]) ForEach(fn func(T)) {
    for _, item := range c.items {
        fn(item)
    }
}

// 方法也可以是泛型的
func (c *Container[T]) Transform[U any](fn func(T) U) *Container[U] {
    result := NewContainer[U]()
    for _, item := range c.items {
        result.Add(fn(item))
    }
    return result
}

func main() {
    // 整数容器
    intContainer := NewContainer[int]()
    intContainer.Add(1)
    intContainer.Add(2)
    intContainer.Add(3)
    
    if val, ok := intContainer.Get(1); ok {
        fmt.Println(val) // 2
    }
    
    // Transform: int → string
    strContainer := intContainer.Transform(func(n int) string {
        return fmt.Sprintf("num-%d", n)
    })
    
    strContainer.ForEach(func(s string) {
        fmt.Println(s)
    })
}

泛型与接口结合

📝 泛型接口

package main

import "fmt"

// 泛型接口
type Processor[T any] interface {
    Process(T) T
}

// 实现泛型接口
type Doubler[T constraints.Numeric] struct{}

func (d Doubler[T]) Process(v T) T {
    return v * 2
}

type UpperCaser struct{}

func (u UpperCaser) Process(s string) string {
    return strings.ToUpper(s)
}

// 使用泛型接口
func ProcessAll[T any](processor Processor[T], items []T) []T {
    results := make([]T, len(items))
    for i, item := range items {
        results[i] = processor.Process(item)
    }
    return results
}

func main() {
    // 数值处理
    doubler := Doubler[int]{}
    nums := []int{1, 2, 3}
    doubled := ProcessAll(doubler, nums)
    fmt.Println(doubled) // [2 4 6]
}

陷阱与注意事项

⚠️ 常见陷阱

// 陷阱 1: 零值问题
func First[T any](slice []T) T {
    if len(slice) == 0 {
        var zero T  // 返回零值
        return zero
    }
    return slice[0]
}
// 问题:调用者无法区分空切片和包含零值的切片
// 解决:返回 (T, bool)

// 陷阱 2: 方法接收者类型
type MyType[T any] struct{}

// ✅ 正确:接收者带类型参数
func (m MyType[T]) Method(v T) {}

// ❌ 错误:方法不能有额外的类型参数
// func (m MyType[T]) Method[U any](v U) {}

// 陷阱 3: 泛型不能用于某些场景
// - 不能用于嵌入字段
// - 不能用于别名
// - 不能用于方法参数

// 陷阱 4: 类型推断失败
func Make[T any]() []T {
    return make([]T, 0)
}
// Make()  // ❌ 无法推断 T
// Make[int]()  // ✅ 必须显式指定

总结

✅ 核心要点

  • Map/Filter/Reduce: 函数式编程核心模式
  • 数值约束: constraints.Numeric/Ordered
  • 集合操作: 泛型 Set 实现
  • 泛型方法: 结构体方法可以是泛型
  • 类型推断: 优先让编译器推断类型
  • 零值处理: 注意空情况的返回值