跳到主要内容

第7章:面向对象编程高级特性 - 让对象拥有"超能力"

学习目标:掌握Python面向对象的高级特性,包括魔术方法、属性管理、描述符、抽象基类和设计模式,能够设计出功能强大、扩展性好的类体系。

想象一下,如果你的对象可以像内置类型一样支持 +-== 等运算符,可以自动验证属性值,可以像容器一样被迭代和索引,那该有多酷?这就是Python面向对象高级特性的魅力——让你的自定义类拥有"超能力"!


🎭 第7.1节:魔术方法与运算符重载 - 给对象添加"魔法"

什么是魔术方法?

魔术方法(Magic Methods),也叫特殊方法或双下划线方法,是Python中以双下划线开头和结尾的特殊方法。它们定义了对象的行为,让你的自定义类可以像内置类型一样工作。

就像给超级英雄赋予超能力一样,魔术方法给普通对象赋予了特殊能力:

# 没有魔术方法的普通类
class SimpleVector:
def __init__(self, x, y):
self.x = x
self.y = y
# 使用时很不方便
v1 = SimpleVector(1, 2)
v2 = SimpleVector(3, 4)
print(v1) # 输出:<__main__.SimpleVector object at 0x...> 看不懂!
# v3 = v1 + v2 # 报错!不支持加法
# 有魔术方法的"超能力"类
class MagicVector:
def __init__(self, x, y):
self.x = x
self.y = y
def __str__(self):
return f"Vector({self.x}, {self.y})"
def __add__(self, other):
return MagicVector(self.x + other.x, self.y + other.y)
# 使用起来就像内置类型一样自然
v1 = MagicVector(1, 2)
v2 = MagicVector(3, 4)
print(v1) # 输出:Vector(1, 2) 清楚明了!
v3 = v1 + v2 # 可以直接相加!
print(v3) # 输出:Vector(4, 6)

对象生命周期魔术方法

1. __new__ vs __init__:对象的"出生"过程

很多人以为 __init__ 是构造函数,其实不然。对象的创建分为两步:

class Person:
def __new__(cls, name, age):
print(f"Step 1: __new__ 正在创建 {name} 的实例")
# __new__ 负责创建实例
instance = super().__new__(cls)
return instance
def __init__(self, name, age):
print(f"Step 2: __init__ 正在初始化 {name}")
# __init__ 负责初始化实例
self.name = name
self.age = age
# 创建对象时会先调用 __new__ 再调用 __init__
person = Person("张三", 25)
# 输出:
# Step 1: __new__ 正在创建 张三 的实例
# Step 2: __init__ 正在初始化 张三

__new__ 的实际应用 - 单例模式

class DatabaseConnection:
_instance = None
def __new__(cls):
# 单例模式:只创建一个实例
if cls._instance is None:
print("创建数据库连接...")
cls._instance = super().__new__(cls)
else:
print("使用现有数据库连接...")
return cls._instance
def __init__(self):
if not hasattr(self, 'initialized'):
self.host = "localhost"
self.port = 3306
self.initialized = True
# 测试单例效果
db1 = DatabaseConnection() # 创建数据库连接...
db2 = DatabaseConnection() # 使用现有数据库连接...
print(db1 is db2) # True,是同一个实例

2. __del__:对象的"葬礼"

class FileManager:
def __init__(self, filename):
self.filename = filename
self.file = open(filename, 'w')
print(f"打开文件 {filename}")
def __del__(self):
# 对象被销毁时自动调用
if hasattr(self, 'file') and not self.file.closed:
self.file.close()
print(f"自动关闭文件 {self.filename}")
# 测试自动资源管理
manager = FileManager("test.txt")
manager = None # 删除引用,触发 __del__
# 输出:自动关闭文件 test.txt

字符串表示魔术方法

__str__ vs __repr__:两种"自我介绍"方式

class Student:
def __init__(self, name, grade, student_id):
self.name = name
self.grade = grade
self.student_id = student_id
def __str__(self):
# 给用户看的友好表示
return f"{self.name}同学({self.grade}年级)"
def __repr__(self):
# 给开发者看的精确表示,最好能重现对象
return f"Student('{self.name}', {self.grade}, '{self.student_id}')"
student = Student("李雷", 3, "2021001")
print(str(student)) # 李雷同学(3年级)
print(repr(student)) # Student('李雷', 3, '2021001')
# 在列表中显示时使用 __repr__
students = [student]
print(students) # [Student('李雷', 3, '2021001')]

最佳实践

  • __str__:写给用户看的,要友好易读
  • __repr__:写给开发者看的,要精确明确,最好能用来重建对象

__format__:自定义格式化

class Money:
def __init__(self, amount):
self.amount = amount
def __format__(self, format_spec):
if format_spec == 'cn':
return f"¥{self.amount:.2f}"
elif format_spec == 'us':
return f"${self.amount:.2f}"
elif format_spec == 'int':
return f"{int(self.amount)}"
else:
return f"{self.amount:.2f}"
money = Money(1234.567)
print(f"中文格式:{money:cn}") # 中文格式:¥1234.57
print(f"美式格式:{money:us}") # 美式格式:$1234.57
print(f"整数格式:{money:int}") # 整数格式:1234

比较运算符重载

class Grade:
def __init__(self, score):
self.score = score
def __eq__(self, other):
"""相等比较 =="""
return self.score == other.score
def __lt__(self, other):
"""小于比较 <"""
return self.score < other.score
def __le__(self, other):
"""小于等于 <="""
return self.score <= other.score
def __gt__(self, other):
"""大于比较 >"""
return self.score > other.score
def __ge__(self, other):
"""大于等于 >="""
return self.score >= other.score
def __ne__(self, other):
"""不等于 !="""
return self.score != other.score
def __str__(self):
return f"Grade({self.score})"
# 测试比较操作
grade1 = Grade(85)
grade2 = Grade(92)
grade3 = Grade(85)
print(f"{grade1} == {grade3}: {grade1 == grade3}") # True
print(f"{grade1} < {grade2}: {grade1 < grade2}") # True
print(f"{grade1} > {grade2}: {grade1 > grade2}") # False
# 可以排序了!
grades = [Grade(78), Grade(92), Grade(85), Grade(67)]
sorted_grades = sorted(grades)
print("排序后的成绩:", [str(g) for g in sorted_grades])
# 输出:排序后的成绩: ['Grade(67)', 'Grade(78)', 'Grade(85)', 'Grade(92)']
```
**技巧**:Python 可以从 `__eq__` 和 `__lt__` 自动推导其他比较方法,使用 `functools.total_ordering` 装饰器:
from functools import total_ordering
@total_ordering
class SimpleGrade:
def __init__(self, score):
self.score = score
def __eq__(self, other):
return self.score == other.score
def __lt__(self, other):
return self.score < other.score
def __str__(self):
return f"Grade({self.score})"
# 只定义了 __eq__ 和 __lt__,但其他比较方法也能工作!
grade1 = SimpleGrade(85)
grade2 = SimpleGrade(92)
print(grade1 <= grade2) # True,自动推导出来的

算术运算符重载

让我们创建一个功能完整的向量类:

class Vector2D:
def __init__(self, x, y):
self.x = x
self.y = y
# 基本算术运算
def __add__(self, other):
"""向量加法"""
return Vector2D(self.x + other.x, self.y + other.y)
def __sub__(self, other):
"""向量减法"""
return Vector2D(self.x - other.x, self.y - other.y)
def __mul__(self, scalar):
"""向量数乘"""
if isinstance(scalar, (int, float)):
return Vector2D(self.x * scalar, self.y * scalar)
else:
raise TypeError("向量只能与数字相乘")
def __rmul__(self, scalar):
"""反向乘法:3 * vector"""
return self.__mul__(scalar)
def __truediv__(self, scalar):
"""向量除法"""
if isinstance(scalar, (int, float)) and scalar != 0:
return Vector2D(self.x / scalar, self.y / scalar)
else:
raise ValueError("除数必须是非零数字")
# 增强赋值运算符
def __iadd__(self, other):
"""+="""
self.x += other.x
self.y += other.y
return self
def __isub__(self, other):
"""-="""
self.x -= other.x
self.y -= other.y
return self
# 一元运算符
def __neg__(self):
"""负号 -vector"""
return Vector2D(-self.x, -self.y)
def __abs__(self):
"""绝对值(向量长度)"""
return (self.x ** 2 + self.y ** 2) ** 0.5
# 比较运算(按长度比较)
def __eq__(self, other):
return abs(self) == abs(other)
def __lt__(self, other):
return abs(self) < abs(other)
def __str__(self):
return f"Vector2D({self.x}, {self.y})"
def __repr__(self):
return f"Vector2D(x={self.x}, y={self.y})"
# 测试完整的向量运算
v1 = Vector2D(3, 4)
v2 = Vector2D(1, 2)
print(f"v1 = {v1}") # v1 = Vector2D(3, 4)
print(f"v2 = {v2}") # v2 = Vector2D(1, 2)
print(f"v1 + v2 = {v1 + v2}") # v1 + v2 = Vector2D(4, 6)
print(f"v1 - v2 = {v1 - v2}") # v1 - v2 = Vector2D(2, 2)
print(f"v1 * 2 = {v1 * 2}") # v1 * 2 = Vector2D(6, 8)
print(f"3 * v1 = {3 * v1}") # 3 * v1 = Vector2D(9, 12)
print(f"-v1 = {-v1}") # -v1 = Vector2D(-3, -4)
print(f"|v1| = {abs(v1)}") # |v1| = 5.0
# 增强赋值
v1 += v2
print(f"v1 += v2 后: {v1}") # v1 += v2 后: Vector2D(4, 6)

容器模拟魔术方法

让我们创建一个自定义的成绩册类,支持像列表一样的操作:

class GradeBook:
def __init__(self):
self._grades = {} # {学生姓名: 成绩}
def __len__(self):
"""支持 len() 函数"""
return len(self._grades)
def __getitem__(self, name):
"""支持 [] 取值"""
if name in self._grades:
return self._grades[name]
else:
raise KeyError(f"学生 {name} 不存在")
def __setitem__(self, name, grade):
"""支持 [] 赋值"""
if not isinstance(grade, (int, float)) or not 0 <= grade <= 100:
raise ValueError("成绩必须是 0-100 之间的数字")
self._grades[name] = grade
def __delitem__(self, name):
"""支持 del 删除"""
if name in self._grades:
del self._grades[name]
else:
raise KeyError(f"学生 {name} 不存在")
def __contains__(self, name):
"""支持 in 操作符"""
return name in self._grades
def __iter__(self):
"""支持 for 循环遍历"""
return iter(self._grades.items())
def __str__(self):
return f"GradeBook({len(self)} students)"
def __repr__(self):
return f"GradeBook({dict(self._grades)})"
# 测试容器功能
gradebook = GradeBook()
# 添加成绩(支持 [] 赋值)
gradebook["张三"] = 85
gradebook["李四"] = 92
gradebook["王五"] = 78
# 获取成绩(支持 [] 取值)
print(f"张三的成绩:{gradebook['张三']}") # 张三的成绩:85
# 检查学生是否存在(支持 in)
print("张三" in gradebook) # True
print("赵六" in gradebook) # False
# 获取学生数量(支持 len)
print(f"学生数量:{len(gradebook)}") # 学生数量:3
# 遍历所有学生(支持 for 循环)
print("所有学生成绩:")
for name, grade in gradebook:
print(f" {name}: {grade}")
# 删除学生(支持 del)
del gradebook["王五"]
print(f"删除王五后学生数量:{len(gradebook)}") # 删除王五后学生数量:2

🏃‍♂️ 实践练习:智能计算器类

练习1:创建一个 Calculator 类,支持基本数学运算和比较。

要求

  1. 支持 +-*/ 四则运算
  2. 支持比较运算符
  3. 支持格式化输出(整数、小数、百分比)
  4. 记录运算历史
class Calculator:
def __init__(self, value=0):
self.value = value
self.history = [] # 运算历史
def _record(self, operation, other, result):
"""记录运算历史"""
self.history.append(f"{self.value} {operation} {other} = {result}")
def __add__(self, other):
if isinstance(other, Calculator):
result = self.value + other.value
else:
result = self.value + other
self._record('+', other, result)
return Calculator(result)
def __sub__(self, other):
# TODO: 实现减法
pass
def __mul__(self, other):
# TODO: 实现乘法
pass
def __truediv__(self, other):
# TODO: 实现除法
pass
def __eq__(self, other):
# TODO: 实现相等比较
pass
def __format__(self, format_spec):
# TODO: 实现格式化
# 'int': 整数格式
# 'float': 小数格式
# 'percent': 百分比格式
pass
def __str__(self):
return str(self.value)
# 测试代码
calc1 = Calculator(10)
calc2 = Calculator(3)
result = calc1 + calc2
print(f"计算结果:{result}") # 计算结果:13

本节小结

魔术方法让你的自定义类具备了"超能力":

  1. 对象生命周期__new____init____del__
  2. 字符串表示__str____repr____format__
  3. 比较运算__eq____lt____le__
  4. 算术运算__add____sub____mul__
  5. 容器模拟__len____getitem____setitem__

下一节我们将学习属性管理和描述符,让对象的属性变得更加智能!


🎛️ 第7.2节:属性管理与描述符 - 让属性变得"智能"

为什么需要智能属性?

在现实世界中,有些属性不能随意设置。比如:

  • 人的年龄不能是负数
  • 温度不能低于绝对零度(-273.15°C)
  • 银行账户余额变化需要记录日志
  • 某些属性需要根据其他属性自动计算

传统的做法是写大量的 getter 和 setter 方法,但 Python 提供了更优雅的解决方案。

property 装饰器:属性的"守门员"

基础用法:只读属性

class Circle:
def __init__(self, radius):
self._radius = radius # 使用下划线表示"内部"属性
@property
def radius(self):
"""半径属性(只读)"""
return self._radius
@property
def area(self):
"""面积属性(只读,自动计算)"""
return 3.14159 * self._radius ** 2
@property
def circumference(self):
"""周长属性(只读,自动计算)"""
return 2 * 3.14159 * self._radius
# 使用像普通属性一样
circle = Circle(5)
print(f"半径:{circle.radius}") # 半径:5
print(f"面积:{circle.area:.2f}") # 面积:78.54
print(f"周长:{circle.circumference:.2f}") # 周长:31.42
# 试图修改只读属性会报错
# circle.area = 100 # AttributeError: can't set attribute

完整用法:getter、setter、deleter

class Temperature:
def __init__(self, celsius=0):
self._celsius = celsius
self._change_log = [] # 记录温度变化
@property
def celsius(self):
"""摄氏温度 getter"""
return self._celsius
@celsius.setter
def celsius(self, value):
"""摄氏温度 setter,带验证"""
if not isinstance(value, (int, float)):
raise TypeError("温度必须是数字")
if value < -273.15:
raise ValueError("温度不能低于绝对零度(-273.15°C)")
# 记录变化
old_value = self._celsius
self._celsius = value
self._change_log.append(f"温度从 {old_value}°C 变为 {value}°C")
@celsius.deleter
def celsius(self):
"""重置温度"""
self._celsius = 0
self._change_log.append("温度被重置为 0°C")
# 华氏温度属性(自动转换)
@property
def fahrenheit(self):
"""华氏温度 getter"""
return self._celsius * 9/5 + 32
@fahrenheit.setter
def fahrenheit(self, value):
"""华氏温度 setter,自动转换为摄氏度"""
if not isinstance(value, (int, float)):
raise TypeError("温度必须是数字")
celsius_value = (value - 32) * 5/9
self.celsius = celsius_value # 使用摄氏度的 setter,包含验证逻辑
@property
def kelvin(self):
"""开尔文温度"""
return self._celsius + 273.15
@property
def change_log(self):
"""温度变化日志"""
return self._change_log.copy()
# 测试温度类
temp = Temperature(25)
print(f"初始温度:{temp.celsius}°C")
print(f"华氏温度:{temp.fahrenheit}°F")
print(f"开尔文温度:{temp.kelvin}K")
# 修改温度
temp.celsius = 30
print(f"修改后温度:{temp.celsius}°C")
# 通过华氏度设置温度
temp.fahrenheit = 86 # 30°C
print(f"通过华氏度设置:{temp.celsius}°C")
# 查看变化日志
print("温度变化记录:")
for log in temp.change_log:
print(f" {log}")
# 删除属性(重置)
del temp.celsius
print(f"重置后温度:{temp.celsius}°C")
# 验证机制测试
try:
temp.celsius = -300 # 低于绝对零度
except ValueError as e:
print(f"错误:{e}")

缓存属性:提高性能

class ExpensiveCalculation:
def __init__(self, data):
self.data = data
self._result_cache = None
@property
def expensive_result(self):
"""昂贵的计算,使用缓存优化"""
if self._result_cache is None:
print("正在进行昂贵的计算...")
# 模拟复杂计算
self._result_cache = sum(x**2 for x in self.data)
return self._result_cache
def clear_cache(self):
"""清除缓存"""
self._result_cache = None
# 测试缓存效果
calc = ExpensiveCalculation([1, 2, 3, 4, 5])
print(calc.expensive_result) # 正在进行昂贵的计算... 55
print(calc.expensive_result) # 55(直接从缓存返回,不再计算)

描述符协议:属性管理的"终极武器"

描述符是实现了描述符协议的类,可以控制对属性的访问。这是 property 装饰器的底层实现机制。

描述符协议的三个方法

class MyDescriptor:
def __get__(self, obj, objtype=None):
"""获取属性时调用"""
print(f"__get__ 被调用:obj={obj}, objtype={objtype}")
return "描述符的值"
def __set__(self, obj, value):
"""设置属性时调用"""
print(f"__set__ 被调用:obj={obj}, value={value}")
def __delete__(self, obj):
"""删除属性时调用"""
print(f"__delete__ 被调用:obj={obj}")
class TestClass:
attr = MyDescriptor() # 类属性
# 测试描述符
test = TestClass()
print(test.attr) # 触发 __get__
test.attr = "新值" # 触发 __set__
del test.attr # 触发 __delete__

实用的描述符:类型验证器

class TypeValidator:
"""类型验证描述符"""
def __init__(self, expected_type, default=None):
self.expected_type = expected_type
self.default = default
self.name = None
def __set_name__(self, owner, name):
"""Python 3.6+ 新特性,自动获取属性名"""
self.name = name
def __get__(self, obj, objtype=None):
if obj is None:
return self
return obj.__dict__.get(self.name, self.default)
def __set__(self, obj, value):
if not isinstance(value, self.expected_type):
raise TypeError(
f"{self.name} 必须是 {self.expected_type.__name__} 类型,"
f"得到的是 {type(value).__name__}"
)
obj.__dict__[self.name] = value
class RangeValidator(TypeValidator):
"""范围验证描述符"""
def __init__(self, expected_type, min_val=None, max_val=None, default=None):
super().__init__(expected_type, default)
self.min_val = min_val
self.max_val = max_val
def __set__(self, obj, value):
# 先进行类型验证
super().__set__(obj, value)
# 再进行范围验证
if self.min_val is not None and value < self.min_val:
raise ValueError(f"{self.name} 不能小于 {self.min_val}")
if self.max_val is not None and value > self.max_val:
raise ValueError(f"{self.name} 不能大于 {self.max_val}")
# 使用描述符创建学生类
class Student:
# 使用描述符定义属性,自动进行类型和范围验证
name = TypeValidator(str, "")
age = RangeValidator(int, min_val=0, max_val=150, default=18)
grade = RangeValidator(float, min_val=0.0, max_val=100.0, default=0.0)
def __init__(self, name, age, grade):
self.name = name # 触发 TypeValidator.__set__
self.age = age # 触发 RangeValidator.__set__
self.grade = grade # 触发 RangeValidator.__set__
def __str__(self):
return f"Student(name='{self.name}', age={self.age}, grade={self.grade})"
# 测试描述符验证
try:
student1 = Student("张三", 20, 85.5)
print(student1) # 正常创建
student2 = Student("李四", -5, 90) # 年龄验证失败
except ValueError as e:
print(f"验证错误:{e}")
try:
student1.grade = 105 # 成绩验证失败
except ValueError as e:
print(f"验证错误:{e}")
try:
student1.name = 123 # 类型验证失败
except TypeError as e:
print(f"类型错误:{e}")

计算属性描述符

class ComputedProperty:
"""计算属性描述符"""
def __init__(self, func):
self.func = func
self.name = func.__name__
def __get__(self, obj, objtype=None):
if obj is None:
return self
# 每次访问都重新计算
return self.func(obj)
def __set__(self, obj, value):
raise AttributeError(f"'{self.name}' 是只读计算属性")
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
@ComputedProperty
def area(self):
"""面积(计算属性)"""
print("计算面积...")
return self.width * self.height
@ComputedProperty
def perimeter(self):
"""周长(计算属性)"""
print("计算周长...")
return 2 * (self.width + self.height)
# 测试计算属性
rect = Rectangle(5, 3)
print(f"面积:{rect.area}") # 计算面积... 面积:15
print(f"周长:{rect.perimeter}") # 计算周长... 周长:16
# 修改尺寸后重新计算
rect.width = 10
print(f"新面积:{rect.area}") # 计算面积... 新面积:30

属性访问控制魔术方法

__getattr____setattr____delattr__

class DynamicObject:
"""支持动态属性的对象"""
def __init__(self):
# 必须使用 super().__setattr__ 避免无限递归
super().__setattr__('_data', {})
super().__setattr__('_access_log', [])
def __getattr__(self, name):
"""当访问不存在的属性时调用"""
self._access_log.append(f"访问属性:{name}")
if name in self._data:
return self._data[name]
else:
raise AttributeError(f"'{type(self).__name__}' 对象没有属性 '{name}'")
def __setattr__(self, name, value):
"""设置任何属性时调用"""
if name.startswith('_'):
# 内部属性直接设置
super().__setattr__(name, value)
else:
# 用户属性存储在 _data 中
if not hasattr(self, '_data'):
super().__setattr__('_data', {})
if not hasattr(self, '_access_log'):
super().__setattr__('_access_log', [])
self._access_log.append(f"设置属性:{name} = {value}")
self._data[name] = value
def __delattr__(self, name):
"""删除属性时调用"""
if name.startswith('_'):
super().__delattr__(name)
else:
if name in self._data:
self._access_log.append(f"删除属性:{name}")
del self._data[name]
else:
raise AttributeError(f"属性 '{name}' 不存在")
def get_access_log(self):
"""获取访问日志"""
return self._access_log.copy()
# 测试动态属性
obj = DynamicObject()
# 动态设置属性
obj.name = "张三"
obj.age = 25
obj.city = "北京"
# 访问属性
print(f"姓名:{obj.name}")
print(f"年龄:{obj.age}")
# 删除属性
del obj.city
# 查看访问日志
print("访问日志:")
for log in obj.get_access_log():
print(f" {log}")

🏃‍♂️ 实践练习:智能配置管理器

练习2:创建一个配置管理器类,支持:

  1. 自动类型转换和验证
  2. 配置项的历史记录
  3. 嵌套配置访问
  4. 配置的保存和加载
class ConfigValidator:
"""配置验证描述符"""
def __init__(self, expected_type, default=None, validator=None):
self.expected_type = expected_type
self.default = default
self.validator = validator # 自定义验证函数
self.name = None
def __set_name__(self, owner, name):
self.name = name
def __get__(self, obj, objtype=None):
if obj is None:
return self
return obj._config.get(self.name, self.default)
def __set__(self, obj, value):
# TODO: 实现类型转换、验证和历史记录
pass
class ConfigManager:
# 定义配置项
host = ConfigValidator(str, "localhost")
port = ConfigValidator(int, 8080, lambda x: 1 <= x <= 65535)
debug = ConfigValidator(bool, False)
timeout = ConfigValidator(float, 30.0, lambda x: x > 0)
def __init__(self):
self._config = {}
self._history = {}
def get_history(self, key):
"""获取配置项的历史记录"""
return self._history.get(key, [])
def save_config(self, filename):
"""保存配置到文件"""
# TODO: 实现配置保存
pass
def load_config(self, filename):
"""从文件加载配置"""
# TODO: 实现配置加载
pass
# 测试代码
config = ConfigManager()
config.host = "192.168.1.1"
config.port = "9000" # 字符串自动转换为整数
config.debug = "true" # 字符串自动转换为布尔值
print(f"主机:{config.host}")
print(f"端口:{config.port}")
print(f"调试模式:{config.debug}")

本节小结

属性管理和描述符为我们提供了强大的工具:

  1. property 装饰器:创建智能属性,支持验证、计算、缓存
  2. 描述符协议:底层的属性控制机制,可重用的属性管理器
  3. 属性访问控制:拦截属性访问,实现动态属性和日志记录

这些特性让我们能够创建更加健壮、智能的类,为复杂应用奠定基础。

下一节我们将学习高级类特性,包括抽象基类、类方法、静态方法和元类编程!


🏛️ 第7.3节:高级类特性 - 构建强大的类体系

类方法与静态方法:选择合适的工具

在Python中,除了普通的实例方法,还有类方法和静态方法。理解它们的区别和使用场景是设计良好类体系的关键。

实例方法 vs 类方法 vs 静态方法

class Student:
# 类变量
school_name = "Python大学"
total_students = 0
def __init__(self, name, age):
self.name = name
self.age = age
Student.total_students += 1
# 实例方法:需要访问实例数据
def introduce(self):
"""实例方法,访问 self"""
return f"我是{self.name},今年{self.age}岁"
# 类方法:需要访问类数据或创建实例
@classmethod
def get_school_info(cls):
"""类方法,访问 cls(类本身)"""
return f"学校:{cls.school_name},总学生数:{cls.total_students}"
@classmethod
def create_adult_student(cls, name):
"""类方法作为替代构造函数"""
return cls(name, 18) # 调用 __init__
# 静态方法:逻辑上属于这个类,但不需要访问类或实例数据
@staticmethod
def is_valid_age(age):
"""静态方法,纯工具函数"""
return 0 <= age <= 120
@staticmethod
def calculate_graduation_year(current_year, grade):
"""静态方法,计算毕业年份"""
return current_year + (4 - grade)
# 使用示例
student1 = Student("张三", 20)
student2 = Student("李四", 19)
# 实例方法调用
print(student1.introduce()) # 我是张三,今年20岁
# 类方法调用(两种方式)
print(Student.get_school_info()) # 学校:Python大学,总学生数:2
print(student1.get_school_info()) # 同样的结果
# 类方法作为替代构造函数
adult_student = Student.create_adult_student("王五")
print(adult_student.introduce()) # 我是王五,今年18岁
# 静态方法调用(三种方式)
print(Student.is_valid_age(25)) # True
print(student1.is_valid_age(-5)) # False
print(Student.calculate_graduation_year(2025, 2)) # 2026

选择指南

  • 实例方法:需要访问或修改实例状态时使用
  • 类方法:需要访问类变量,或作为替代构造函数时使用
  • 静态方法:逻辑上属于类,但不需要访问类或实例数据的工具函数

实际应用:数据库连接管理

import json
from datetime import datetime
class DatabaseConnection:
# 类变量:连接池配置
max_connections = 10
active_connections = 0
connection_history = []
def __init__(self, host, port, database):
if DatabaseConnection.active_connections >= DatabaseConnection.max_connections:
raise Exception("连接池已满")
self.host = host
self.port = port
self.database = database
self.connected_at = datetime.now()
DatabaseConnection.active_connections += 1
DatabaseConnection.connection_history.append({
'host': host,
'database': database,
'connected_at': self.connected_at.isoformat()
})
# 实例方法:操作具体连接
def execute_query(self, sql):
"""执行SQL查询"""
return f"在 {self.database} 执行:{sql}"
def close(self):
"""关闭连接"""
DatabaseConnection.active_connections -= 1
print(f"连接 {self.host}:{self.database} 已关闭")
# 类方法:管理连接池状态
@classmethod
def get_pool_status(cls):
"""获取连接池状态"""
return {
'active': cls.active_connections,
'max': cls.max_connections,
'available': cls.max_connections - cls.active_connections
}
@classmethod
def create_local_connection(cls, database):
"""创建本地数据库连接(替代构造函数)"""
return cls('localhost', 3306, database)
@classmethod
def create_from_config(cls, config_file):
"""从配置文件创建连接(替代构造函数)"""
with open(config_file, 'r') as f:
config = json.load(f)
return cls(config['host'], config['port'], config['database'])
# 静态方法:工具函数
@staticmethod
def validate_host(host):
"""验证主机地址格式"""
import re
pattern = r'^(\d{1,3}\.){3}\d{1,3}$|^localhost$|^[\w.-]+$'
return bool(re.match(pattern, host))
@staticmethod
def parse_connection_string(conn_str):
"""解析连接字符串"""
# 示例:mysql://user:pass@host:port/database
import re
pattern = r'(\w+)://(\w+):(\w+)@([^:]+):(\d+)/(\w+)'
match = re.match(pattern, conn_str)
if match:
return {
'type': match.group(1),
'user': match.group(2),
'password': match.group(3),
'host': match.group(4),
'port': int(match.group(5)),
'database': match.group(6)
}
return None
# 使用示例
# 验证主机地址(静态方法)
print(DatabaseConnection.validate_host("192.168.1.1")) # True
print(DatabaseConnection.validate_host("localhost")) # True
# 解析连接字符串(静态方法)
conn_info = DatabaseConnection.parse_connection_string(
"mysql://admin:123456@localhost:3306/testdb"
)
print(conn_info)
# 创建连接(类方法作为构造函数)
db1 = DatabaseConnection.create_local_connection("myapp")
db2 = DatabaseConnection.create_local_connection("logs")
# 查看连接池状态(类方法)
print(DatabaseConnection.get_pool_status()) # {'active': 2, 'max': 10, 'available': 8}
# 使用连接(实例方法)
result = db1.execute_query("SELECT * FROM users")
print(result)
# 关闭连接
db1.close()

抽象基类(ABC):定义接口契约

抽象基类用于定义接口规范,确保子类实现必要的方法。这是面向对象设计中的重要模式。

基础抽象基类

from abc import ABC, abstractmethod
class Shape(ABC):
"""形状抽象基类"""
def __init__(self, name):
self.name = name
@abstractmethod
def area(self):
"""计算面积(抽象方法,子类必须实现)"""
pass
@abstractmethod
def perimeter(self):
"""计算周长(抽象方法,子类必须实现)"""
pass
# 具体方法(子类可以直接使用)
def display_info(self):
"""显示形状信息"""
return f"{self.name}: 面积={self.area():.2f}, 周长={self.perimeter():.2f}"
@classmethod
def from_string(cls, shape_str):
"""从字符串创建形状(模板方法)"""
# 这是一个模板方法,定义了创建流程
parts = shape_str.split(',')
if len(parts) < 2:
raise ValueError("格式错误")
shape_type = parts[0].strip()
if shape_type == "circle":
radius = float(parts[1])
return Circle(radius)
elif shape_type == "rectangle":
width, height = float(parts[1]), float(parts[2])
return Rectangle(width, height)
else:
raise ValueError(f"未知的形状类型:{shape_type}")
class Circle(Shape):
def __init__(self, radius):
super().__init__("圆形")
if radius <= 0:
raise ValueError("半径必须大于0")
self.radius = radius
def area(self):
return 3.14159 * self.radius ** 2
def perimeter(self):
return 2 * 3.14159 * self.radius
class Rectangle(Shape):
def __init__(self, width, height):
super().__init__("矩形")
if width <= 0 or height <= 0:
raise ValueError("宽度和高度必须大于0")
self.width = width
self.height = height
def area(self):
return self.width * self.height
def perimeter(self):
return 2 * (self.width + self.height)
# 使用示例
# shape = Shape("抽象形状") # 报错!不能实例化抽象类
circle = Circle(5)
rectangle = Rectangle(4, 6)
print(circle.display_info()) # 圆形: 面积=78.54, 周长=31.42
print(rectangle.display_info()) # 矩形: 面积=24.00, 周长=20.00
# 使用类方法创建对象
shape1 = Shape.from_string("circle, 3")
shape2 = Shape.from_string("rectangle, 2, 8")
print(shape1.display_info()) # 圆形: 面积=28.27, 周长=18.85
print(shape2.display_info()) # 矩形: 面积=16.00, 周长=20.00

复杂抽象基类:数据处理器

from abc import ABC, abstractmethod
from typing import Any, List, Dict
class DataProcessor(ABC):
"""数据处理器抽象基类"""
def __init__(self, name: str):
self.name = name
self._processed_count = 0
@abstractmethod
def validate(self, data: Any) -> bool:
"""验证数据格式"""
pass
@abstractmethod
def process_single(self, data: Any) -> Any:
"""处理单条数据"""
pass
@abstractmethod
def get_output_format(self) -> str:
"""获取输出格式说明"""
pass
# 模板方法:定义处理流程
def process(self, data_list: List[Any]) -> List[Any]:
"""处理数据列表(模板方法)"""
results = []
for data in data_list:
try:
# 1. 验证数据
if not self.validate(data):
print(f"数据验证失败:{data}")
continue
# 2. 处理数据
result = self.process_single(data)
results.append(result)
self._processed_count += 1
except Exception as e:
print(f"处理数据时出错:{data}, 错误:{e}")
return results
def get_statistics(self) -> Dict[str, Any]:
"""获取处理统计信息"""
return {
'processor_name': self.name,
'processed_count': self._processed_count,
'output_format': self.get_output_format()
}
class TextProcessor(DataProcessor):
"""文本处理器"""
def __init__(self, transform_type="upper"):
super().__init__(f"文本处理器({transform_type})")
self.transform_type = transform_type
def validate(self, data):
return isinstance(data, str) and len(data.strip()) > 0
def process_single(self, data):
text = data.strip()
if self.transform_type == "upper":
return text.upper()
elif self.transform_type == "lower":
return text.lower()
elif self.transform_type == "title":
return text.title()
elif self.transform_type == "reverse":
return text[::-1]
else:
return text
def get_output_format(self):
return f"转换后的文本({self.transform_type})"
class NumberProcessor(DataProcessor):
"""数字处理器"""
def __init__(self, operation="square"):
super().__init__(f"数字处理器({operation})")
self.operation = operation
def validate(self, data):
try:
float(data)
return True
except (ValueError, TypeError):
return False
def process_single(self, data):
num = float(data)
if self.operation == "square":
return num ** 2
elif self.operation == "sqrt":
return num ** 0.5
elif self.operation == "abs":
return abs(num)
elif self.operation == "round":
return round(num, 2)
else:
return num
def get_output_format(self):
return f"处理后的数字({self.operation})"
# 使用示例
text_processor = TextProcessor("upper")
number_processor = NumberProcessor("square")
# 处理文本数据
texts = ["hello world", "python programming", " ", "abc123"]
processed_texts = text_processor.process(texts)
print("处理后的文本:", processed_texts)
print("文本处理统计:", text_processor.get_statistics())
# 处理数字数据
numbers = ["3.14", "25", "invalid", "-5", "0"]
processed_numbers = number_processor.process(numbers)
print("处理后的数字:", processed_numbers)
print("数字处理统计:", number_processor.get_statistics())

混入模式(Mixin):功能的"乐高积木"

混入模式允许我们将功能分解为小的、可重用的组件,然后通过多重继承组合它们。

基础混入示例

# 混入类:提供特定功能
class TimestampMixin:
"""时间戳混入"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from datetime import datetime
self.created_at = datetime.now()
self.updated_at = datetime.now()
def touch(self):
"""更新时间戳"""
from datetime import datetime
self.updated_at = datetime.now()
def get_age(self):
"""获取对象年龄(秒)"""
from datetime import datetime
return (datetime.now() - self.created_at).total_seconds()
class ValidationMixin:
"""验证混入"""
def validate(self):
"""验证对象状态"""
errors = []
# 调用所有以 validate_ 开头的方法
for method_name in dir(self):
if method_name.startswith('validate_') and callable(getattr(self, method_name)):
try:
getattr(self, method_name)()
except ValueError as e:
errors.append(str(e))
if errors:
raise ValueError(f"验证失败:{'; '.join(errors)}")
return True
class SerializationMixin:
"""序列化混入"""
def to_dict(self):
"""转换为字典"""
result = {}
for key, value in self.__dict__.items():
if not key.startswith('_'):
if hasattr(value, 'isoformat'): # datetime对象
result[key] = value.isoformat()
else:
result[key] = value
return result
@classmethod
def from_dict(cls, data):
"""从字典创建对象"""
# 过滤掉时间戳字段,让TimestampMixin自动设置
filtered_data = {k: v for k, v in data.items()
if k not in ['created_at', 'updated_at']}
return cls(**filtered_data)
def to_json(self):
"""转换为JSON字符串"""
import json
return json.dumps(self.to_dict(), indent=2)
# 使用混入的具体类
class User(TimestampMixin, ValidationMixin, SerializationMixin):
"""用户类,组合多个混入功能"""
def __init__(self, username, email, age):
# 重要:调用父类的__init__以启用混入功能
super().__init__()
self.username = username
self.email = email
self.age = age
# 验证方法(被ValidationMixin调用)
def validate_username(self):
if not self.username or len(self.username) < 3:
raise ValueError("用户名长度至少3个字符")
def validate_email(self):
if '@' not in self.email:
raise ValueError("邮箱格式无效")
def validate_age(self):
if not isinstance(self.age, int) or self.age < 0:
raise ValueError("年龄必须是非负整数")
def __str__(self):
return f"User({self.username}, {self.email})"
# 另一个使用混入的类
class Product(TimestampMixin, ValidationMixin, SerializationMixin):
"""产品类"""
def __init__(self, name, price, category):
super().__init__()
self.name = name
self.price = price
self.category = category
def validate_name(self):
if not self.name or len(self.name.strip()) == 0:
raise ValueError("产品名称不能为空")
def validate_price(self):
if not isinstance(self.price, (int, float)) or self.price < 0:
raise ValueError("价格必须是非负数")
def __str__(self):
return f"Product({self.name}, ¥{self.price})"
# 使用示例
import time
# 创建用户
user = User("alice", "alice@example.com", 25)
print(f"创建用户:{user}")
# 验证功能
try:
user.validate()
print("用户验证通过")
except ValueError as e:
print(f"验证失败:{e}")
# 时间戳功能
print(f"用户创建时间:{user.created_at}")
time.sleep(1)
user.touch() # 更新时间戳
print(f"用户更新时间:{user.updated_at}")
print(f"用户年龄:{user.get_age():.2f}秒")
# 序列化功能
user_dict = user.to_dict()
print("用户字典:", user_dict)
user_json = user.to_json()
print("用户JSON:", user_json)
# 从字典重建对象
new_user = User.from_dict({'username': 'bob', 'email': 'bob@test.com', 'age': 30})
print(f"重建用户:{new_user}")
# 创建产品
product = Product("iPhone", 999.99, "electronics")
print(f"创建产品:{product}")
print("产品字典:", product.to_dict())

高级混入:缓存系统

import time
import hashlib
from functools import wraps
class CacheMixin:
"""缓存混入"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cache = {}
self._cache_stats = {'hits': 0, 'misses': 0}
def _get_cache_key(self, method_name, args, kwargs):
"""生成缓存键"""
key_data = f"{method_name}:{args}:{sorted(kwargs.items())}"
return hashlib.md5(key_data.encode()).hexdigest()
def cached_method(self, expire_time=300):
"""缓存装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
cache_key = self._get_cache_key(func.__name__, args[1:], kwargs)
current_time = time.time()
# 检查缓存
if cache_key in self._cache:
cached_data, cached_time = self._cache[cache_key]
if current_time - cached_time < expire_time:
self._cache_stats['hits'] += 1
print(f"缓存命中:{func.__name__}")
return cached_data
# 缓存未命中,执行方法
self._cache_stats['misses'] += 1
result = func(*args, **kwargs)
self._cache[cache_key] = (result, current_time)
print(f"缓存设置:{func.__name__}")
return result
return wrapper
return decorator
def clear_cache(self):
"""清除缓存"""
self._cache.clear()
print("缓存已清除")
def get_cache_stats(self):
"""获取缓存统计"""
total = self._cache_stats['hits'] + self._cache_stats['misses']
hit_rate = self._cache_stats['hits'] / total if total > 0 else 0
return {
'hits': self._cache_stats['hits'],
'misses': self._cache_stats['misses'],
'hit_rate': f"{hit_rate:.2%}",
'cache_size': len(self._cache)
}
class LogMixin:
"""日志混入"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logs = []
def log(self, message, level="INFO"):
"""记录日志"""
import datetime
timestamp = datetime.datetime.now().isoformat()
log_entry = f"[{timestamp}] [{level}] {message}"
self._logs.append(log_entry)
print(log_entry)
def get_logs(self):
"""获取日志"""
return self._logs.copy()
def clear_logs(self):
"""清除日志"""
self._logs.clear()
class DatabaseMixin:
"""数据库操作混入"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._db_operations = []
def save(self):
"""保存到数据库"""
operation = f"INSERT INTO {self.__class__.__name__.lower()} ..."
self._db_operations.append(operation)
if hasattr(self, 'log'):
self.log(f"保存操作:{operation}")
return f"已保存:{self}"
def update(self):
"""更新数据库"""
operation = f"UPDATE {self.__class__.__name__.lower()} ..."
self._db_operations.append(operation)
if hasattr(self, 'log'):
self.log(f"更新操作:{operation}")
return f"已更新:{self}"
def delete(self):
"""从数据库删除"""
operation = f"DELETE FROM {self.__class__.__name__.lower()} ..."
self._db_operations.append(operation)
if hasattr(self, 'log'):
self.log(f"删除操作:{operation}")
return f"已删除:{self}"
def get_db_operations(self):
"""获取数据库操作历史"""
return self._db_operations.copy()
# 使用多个混入的复杂类
class AdvancedUser(CacheMixin, LogMixin, DatabaseMixin, TimestampMixin):
"""高级用户类,组合多种功能"""
def __init__(self, username, email):
super().__init__()
self.username = username
self.email = email
self.log(f"创建用户:{username}")
@property # 注意:需要在方法定义后应用装饰器
def profile_data(self):
"""获取用户资料(使用缓存)"""
return self._get_profile_data()
def _get_profile_data(self):
"""内部方法,获取用户资料"""
# 模拟耗时操作
time.sleep(0.5)
return {
'username': self.username,
'email': self.email,
'created_at': self.created_at.isoformat(),
'status': 'active'
}
def __str__(self):
return f"AdvancedUser({self.username})"
# 为方法添加缓存装饰器
def apply_cache_to_method(instance, method_name, expire_time=300):
"""动态为方法添加缓存"""
original_method = getattr(instance, method_name)
cached_method = instance.cached_method(expire_time)(original_method)
setattr(instance, method_name, cached_method)
# 使用示例
user = AdvancedUser("charlie", "charlie@example.com")
# 应用缓存到特定方法
apply_cache_to_method(user, '_get_profile_data', expire_time=10)
# 测试缓存功能
print("第一次获取资料:")
profile1 = user._get_profile_data()
print(profile1)
print("\n第二次获取资料(应该命中缓存):")
profile2 = user._get_profile_data()
print(profile2)
# 查看缓存统计
print("\n缓存统计:", user.get_cache_stats())
# 测试数据库操作
user.save()
user.update()
# 查看日志
print("\n操作日志:")
for log in user.get_logs():
print(log)
# 查看数据库操作历史
print("\n数据库操作:", user.get_db_operations())

方法解析顺序(MRO):理解多重继承

当使用多重继承时,Python使用C3线性化算法确定方法解析顺序。

MRO的重要性

class A:
def method(self):
print("A.method")
class B(A):
def method(self):
print("B.method")
super().method()
class C(A):
def method(self):
print("C.method")
super().method()
class D(B, C):
def method(self):
print("D.method")
super().method()
# 查看方法解析顺序
print("MRO:", D.__mro__)
# 输出:(<class '__main__.D'>, <class '__main__.B'>, <class '__main__.C'>, <class '__main__.A'>, <class 'object'>)
print("MRO列表:", [cls.__name__ for cls in D.__mro__])
# 输出:['D', 'B', 'C', 'A', 'object']
# 调用方法,观察执行顺序
d = D()
d.method()
# 输出:
# D.method
# B.method
# C.method
# A.method

设计良好的混入体系

class EventMixin:
"""事件混入基类"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._event_handlers = {}
def on(self, event_name, handler):
"""注册事件处理器"""
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
self._event_handlers[event_name].append(handler)
def emit(self, event_name, *args, **kwargs):
"""触发事件"""
if event_name in self._event_handlers:
for handler in self._event_handlers[event_name]:
handler(*args, **kwargs)
class ValidatedModel(ValidationMixin, EventMixin):
"""经过验证的模型基类"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.on('before_save', self._before_save)
self.on('after_save', self._after_save)
def _before_save(self):
print("保存前验证...")
self.validate()
def _after_save(self):
print("保存后清理...")
def save(self):
self.emit('before_save')
# 执行保存逻辑
print("正在保存...")
self.emit('after_save')
return True
class TrackedModel(TimestampMixin, LogMixin, ValidatedModel):
"""可追踪的模型"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.log("模型已创建")
def save(self):
self.log("开始保存模型")
result = super().save()
self.touch() # 更新时间戳
self.log("模型保存完成")
return result
class BlogPost(TrackedModel):
"""博客文章类"""
def __init__(self, title, content, author):
super().__init__()
self.title = title
self.content = content
self.author = author
def validate_title(self):
if not self.title or len(self.title.strip()) < 5:
raise ValueError("标题长度至少5个字符")
def validate_content(self):
if not self.content or len(self.content.strip()) < 10:
raise ValueError("内容长度至少10个字符")
def __str__(self):
return f"BlogPost('{self.title}' by {self.author})"
# 查看MRO
print("BlogPost MRO:")
for i, cls in enumerate(BlogPost.__mro__):
print(f" {i+1}. {cls.__name__}")
# 使用示例
post = BlogPost("Python高级特性", "这是一篇关于Python面向对象高级特性的文章...", "作者")
# 添加自定义事件处理器
post.on('before_save', lambda: print("自定义:保存前检查"))
post.on('after_save', lambda: print("自定义:保存后通知"))
# 保存博客文章
print("\n保存博客文章:")
post.save()
# 查看日志
print("\n操作日志:")
for log in post.get_logs():
print(f" {log}")

🏃‍♂️ 实践练习:插件系统设计

练习3:设计一个插件系统,支持动态加载和卸载功能模块。

from abc import ABC, abstractmethod
from typing import Dict, List, Any
class PluginInterface(ABC):
"""插件接口"""
@property
@abstractmethod
def name(self) -> str:
"""插件名称"""
pass
@property
@abstractmethod
def version(self) -> str:
"""插件版本"""
pass
@abstractmethod
def install(self) -> bool:
"""安装插件"""
pass
@abstractmethod
def uninstall(self) -> bool:
"""卸载插件"""
pass
@abstractmethod
def execute(self, *args, **kwargs) -> Any:
"""执行插件功能"""
pass
class PluginManager:
"""插件管理器"""
def __init__(self):
self._plugins: Dict[str, PluginInterface] = {}
self._installed_plugins: List[str] = []
def register_plugin(self, plugin: PluginInterface):
"""注册插件"""
# TODO: 实现插件注册逻辑
pass
def unregister_plugin(self, plugin_name: str):
"""注销插件"""
# TODO: 实现插件注销逻辑
pass
def install_plugin(self, plugin_name: str):
"""安装插件"""
# TODO: 实现插件安装逻辑
pass
def uninstall_plugin(self, plugin_name: str):
"""卸载插件"""
# TODO: 实现插件卸载逻辑
pass
def execute_plugin(self, plugin_name: str, *args, **kwargs):
"""执行插件"""
# TODO: 实现插件执行逻辑
pass
def get_installed_plugins(self) -> List[str]:
"""获取已安装的插件列表"""
return self._installed_plugins.copy()
# 示例插件实现
class LoggerPlugin(PluginInterface):
@property
def name(self):
return "logger"
@property
def version(self):
return "1.0.0"
def install(self):
print(f"安装日志插件 {self.version}")
return True
def uninstall(self):
print("卸载日志插件")
return True
def execute(self, message, level="INFO"):
import datetime
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"[{timestamp}] [{level}] {message}")
# TODO: 实现完整的插件系统

本节小结

高级类特性为我们提供了构建复杂系统的工具:

  1. 方法类型选择:实例方法、类方法、静态方法各有用途
  2. 抽象基类:定义接口契约,确保实现一致性
  3. 混入模式:功能模块化,提高代码复用性
  4. MRO理解:掌握多重继承的方法解析顺序

下一节我们将学习经典设计模式的Python实现!


🎨 第7.4节:设计模式应用 - 经典设计的Python实现

设计模式是解决特定问题的经典方案。掌握这些模式不仅能提高代码质量,还能让你在团队中更好地沟通设计思路。

单例模式:全局唯一的实例

单例模式确保一个类只有一个实例,并提供全局访问点。常用于配置管理、日志系统、数据库连接池等场景。

方法1:使用__new__方法

class Singleton:
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# 防止重复初始化
if not Singleton._initialized:
print("单例初始化")
self.value = 0
Singleton._initialized = True
# 使用示例
s1 = Singleton()
s2 = Singleton()
print(s1 is s2) # True,同一个实例
s1.value = 100
print(s2.value) # 100,共享状态

方法2:使用装饰器

def singleton(cls):
"""单例装饰器"""
instances = {}
def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
@singleton
class DatabaseConfig:
def __init__(self):
self.host = "localhost"
self.port = 3306
self.database = "myapp"
print(f"创建数据库配置:{self.host}:{self.port}")
def get_connection_string(self):
return f"mysql://{self.host}:{self.port}/{self.database}"
# 使用示例
config1 = DatabaseConfig()
config2 = DatabaseConfig()
print(config1 is config2) # True
print(config1.get_connection_string())

方法3:使用元类(高级)

class SingletonMeta(type):
"""单例元类"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
class Logger(metaclass=SingletonMeta):
def __init__(self):
self.logs = []
print("日志系统初始化")
def log(self, message):
import datetime
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = f"[{timestamp}] {message}"
self.logs.append(log_entry)
print(log_entry)
def get_logs(self):
return self.logs.copy()
# 使用示例
logger1 = Logger()
logger2 = Logger()
print(logger1 is logger2) # True
logger1.log("用户登录")
logger2.log("数据查询")
print(len(logger1.get_logs())) # 2,共享日志

实际应用:配置管理系统

import json
import os
from threading import Lock
class ConfigManager:
_instance = None
_lock = Lock() # 线程安全
_initialized = False
def __new__(cls):
if cls._instance is None:
with cls._lock: # 双重检查锁定
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ConfigManager._initialized:
self._config = {}
self._config_file = "app_config.json"
self.load_config()
ConfigManager._initialized = True
def load_config(self):
"""从文件加载配置"""
if os.path.exists(self._config_file):
try:
with open(self._config_file, 'r') as f:
self._config = json.load(f)
print(f"配置已从 {self._config_file} 加载")
except Exception as e:
print(f"加载配置失败:{e}")
self._config = self._get_default_config()
else:
self._config = self._get_default_config()
self.save_config()
def _get_default_config(self):
"""获取默认配置"""
return {
"database": {
"host": "localhost",
"port": 3306,
"name": "myapp"
},
"logging": {
"level": "INFO",
"file": "app.log"
},
"app": {
"debug": False,
"max_connections": 100
}
}
def save_config(self):
"""保存配置到文件"""
try:
with open(self._config_file, 'w') as f:
json.dump(self._config, f, indent=2)
print(f"配置已保存到 {self._config_file}")
except Exception as e:
print(f"保存配置失败:{e}")
def get(self, key, default=None):
"""获取配置值(支持嵌套键,如 'database.host')"""
keys = key.split('.')
value = self._config
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def set(self, key, value):
"""设置配置值(支持嵌套键)"""
keys = key.split('.')
config = self._config
# 导航到最后一级
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
# 设置值
config[keys[-1]] = value
self.save_config()
# 使用示例
config = ConfigManager()
print("数据库主机:", config.get("database.host"))
print("调试模式:", config.get("app.debug"))
config.set("app.debug", True)
config.set("database.timeout", 30)
# 另一个地方获取配置(同一个实例)
config2 = ConfigManager()
print("新的调试模式:", config2.get("app.debug")) # True

工厂模式:动态创建对象

工厂模式用于根据条件动态创建不同类型的对象,避免直接使用类名创建对象。

简单工厂模式

from abc import ABC, abstractmethod
# 产品接口
class Notification(ABC):
@abstractmethod
def send(self, message: str) -> str:
pass
# 具体产品
class EmailNotification(Notification):
def __init__(self, recipient):
self.recipient = recipient
def send(self, message):
return f"邮件发送给 {self.recipient}: {message}"
class SMSNotification(Notification):
def __init__(self, phone):
self.phone = phone
def send(self, message):
return f"短信发送到 {self.phone}: {message}"
class PushNotification(Notification):
def __init__(self, device_id):
self.device_id = device_id
def send(self, message):
return f"推送到设备 {self.device_id}: {message}"
# 简单工厂
class NotificationFactory:
@staticmethod
def create_notification(notification_type: str, target: str) -> Notification:
if notification_type.lower() == "email":
return EmailNotification(target)
elif notification_type.lower() == "sms":
return SMSNotification(target)
elif notification_type.lower() == "push":
return PushNotification(target)
else:
raise ValueError(f"不支持的通知类型:{notification_type}")
# 使用示例
factory = NotificationFactory()
email = factory.create_notification("email", "user@example.com")
sms = factory.create_notification("sms", "13800138000")
push = factory.create_notification("push", "device123")
print(email.send("欢迎注册!"))
print(sms.send("验证码:123456"))
print(push.send("您有新消息"))

抽象工厂模式

# 抽象产品家族
class Button(ABC):
@abstractmethod
def render(self):
pass
class Dialog(ABC):
@abstractmethod
def render(self):
pass
# 具体产品家族 - Windows风格
class WindowsButton(Button):
def render(self):
return "渲染Windows风格按钮"
class WindowsDialog(Dialog):
def render(self):
return "渲染Windows风格对话框"
# 具体产品家族 - Mac风格
class MacButton(Button):
def render(self):
return "渲染Mac风格按钮"
class MacDialog(Dialog):
def render(self):
return "渲染Mac风格对话框"
# 抽象工厂
class GUIFactory(ABC):
@abstractmethod
def create_button(self) -> Button:
pass
@abstractmethod
def create_dialog(self) -> Dialog:
pass
# 具体工厂
class WindowsFactory(GUIFactory):
def create_button(self):
return WindowsButton()
def create_dialog(self):
return WindowsDialog()
class MacFactory(GUIFactory):
def create_button(self):
return MacButton()
def create_dialog(self):
return MacDialog()
# 客户端代码
class Application:
def __init__(self, factory: GUIFactory):
self.factory = factory
def create_ui(self):
button = self.factory.create_button()
dialog = self.factory.create_dialog()
return {
'button': button.render(),
'dialog': dialog.render()
}
# 使用示例
import platform
# 根据操作系统选择工厂
if platform.system() == "Windows":
factory = WindowsFactory()
elif platform.system() == "Darwin": # macOS
factory = MacFactory()
else:
factory = WindowsFactory() # 默认
app = Application(factory)
ui_components = app.create_ui()
print(ui_components)

观察者模式:发布-订阅系统

观察者模式定义了对象间的一对多依赖关系,当一个对象状态改变时,所有依赖者都会得到通知。

基础观察者模式

from abc import ABC, abstractmethod
from typing import List
# 观察者接口
class Observer(ABC):
@abstractmethod
def update(self, subject, event_data):
pass
# 主题接口
class Subject(ABC):
def __init__(self):
self._observers: List[Observer] = []
def attach(self, observer: Observer):
"""添加观察者"""
if observer not in self._observers:
self._observers.append(observer)
def detach(self, observer: Observer):
"""移除观察者"""
if observer in self._observers:
self._observers.remove(observer)
def notify(self, event_data=None):
"""通知所有观察者"""
for observer in self._observers:
observer.update(self, event_data)
# 具体主题:新闻发布系统
class NewsAgency(Subject):
def __init__(self):
super().__init__()
self._news = None
self._category = None
def set_news(self, news, category="general"):
"""发布新闻"""
self._news = news
self._category = category
print(f"新闻发布:[{category}] {news}")
self.notify({
'news': news,
'category': category,
'timestamp': self._get_timestamp()
})
def _get_timestamp(self):
import datetime
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_news(self):
return self._news
# 具体观察者
class NewsChannel(Observer):
def __init__(self, name):
self.name = name
self.latest_news = None
def update(self, subject, event_data):
self.latest_news = event_data
print(f"📺 {self.name} 收到新闻:{event_data['news']}")
class EmailSubscriber(Observer):
def __init__(self, email):
self.email = email
self.subscribed_categories = set()
def subscribe_category(self, category):
"""订阅特定类别"""
self.subscribed_categories.add(category)
def update(self, subject, event_data):
category = event_data.get('category', 'general')
# 只接收订阅的类别
if not self.subscribed_categories or category in self.subscribed_categories:
print(f"📧 发送邮件到 {self.email}{event_data['news']}")
class MobileApp(Observer):
def __init__(self, app_name):
self.app_name = app_name
self.notification_count = 0
def update(self, subject, event_data):
self.notification_count += 1
print(f"📱 {self.app_name} 推送通知 #{self.notification_count}{event_data['news']}")
# 使用示例
news_agency = NewsAgency()
# 创建观察者
tv_channel = NewsChannel("CNN新闻台")
email_subscriber = EmailSubscriber("user@example.com")
mobile_app = MobileApp("新闻客户端")
# 邮件订阅者只订阅特定类别
email_subscriber.subscribe_category("technology")
email_subscriber.subscribe_category("science")
# 注册观察者
news_agency.attach(tv_channel)
news_agency.attach(email_subscriber)
news_agency.attach(mobile_app)
# 发布不同类别的新闻
print("=== 发布科技新闻 ===")
news_agency.set_news("Python 3.12正式发布", "technology")
print("\n=== 发布体育新闻 ===")
news_agency.set_news("世界杯决赛结果出炉", "sports")
print("\n=== 移除观察者 ===")
news_agency.detach(mobile_app)
news_agency.set_news("重大科学发现", "science")

高级观察者模式:事件系统

from collections import defaultdict
from typing import Callable, Any
import functools
class EventManager:
"""高级事件管理器"""
def __init__(self):
self._handlers = defaultdict(list)
self._middleware = []
def on(self, event_name: str, handler: Callable = None, priority: int = 0):
"""注册事件处理器(支持装饰器用法)"""
def decorator(func):
self._handlers[event_name].append({
'handler': func,
'priority': priority
})
# 按优先级排序(高优先级先执行)
self._handlers[event_name].sort(key=lambda x: x['priority'], reverse=True)
return func
if handler is None:
return decorator
else:
return decorator(handler)
def off(self, event_name: str, handler: Callable = None):
"""移除事件处理器"""
if handler is None:
# 移除所有处理器
if event_name in self._handlers:
del self._handlers[event_name]
else:
# 移除特定处理器
self._handlers[event_name] = [
h for h in self._handlers[event_name]
if h['handler'] != handler
]
def emit(self, event_name: str, *args, **kwargs):
"""触发事件"""
event_data = {
'event_name': event_name,
'args': args,
'kwargs': kwargs,
'stopped': False
}
# 执行中间件
for middleware in self._middleware:
middleware(event_data)
if event_data['stopped']:
return
# 执行事件处理器
for handler_info in self._handlers[event_name]:
try:
result = handler_info['handler'](*args, **kwargs)
if result is False: # 明确返回False时停止传播
break
except Exception as e:
print(f"事件处理器错误:{e}")
def middleware(self, func: Callable):
"""注册中间件"""
self.middleware.append(func)
return func
def once(self, event_name: str, handler: Callable = None):
"""一次性事件处理器"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
self.off(event_name, wrapper) # 执行后自动移除
return result
self.on(event_name, wrapper)
return wrapper
if handler is None:
return decorator
else:
return decorator(handler)
# 使用示例:用户管理系统
user_events = EventManager()
# 注册事件处理器(装饰器方式)
@user_events.on("user_login", priority=10)
def log_user_login(user_id, timestamp):
print(f"📝 记录用户登录:{user_id} at {timestamp}")
@user_events.on("user_login", priority=5)
def send_welcome_email(user_id, timestamp):
print(f"📧 发送欢迎邮件给用户:{user_id}")
@user_events.on("user_login")
def update_last_seen(user_id, timestamp):
print(f"🕒 更新用户最后登录时间:{user_id}")
# 一次性事件处理器
@user_events.once("user_first_login")
def setup_user_profile(user_id):
print(f"🎉 首次登录,设置用户档案:{user_id}")
# 添加中间件
@user_events.middleware
def security_check(event_data):
"""安全检查中间件"""
if event_data['event_name'] == 'user_login':
user_id = event_data['args'][0]
if user_id == 'banned_user':
print("🚫 用户被禁止,阻止事件传播")
event_data['stopped'] = True
# 模拟用户操作
import datetime
timestamp = datetime.datetime.now().isoformat()
print("=== 正常用户登录 ===")
user_events.emit("user_login", "user123", timestamp)
print("\n=== 首次登录用户 ===")
user_events.emit("user_first_login", "newuser456")
user_events.emit("user_first_login", "newuser456") # 第二次不会触发
print("\n=== 被禁用户尝试登录 ===")
user_events.emit("user_login", "banned_user", timestamp)

装饰器模式:动态添加功能

装饰器模式允许在运行时为对象动态添加新功能,而不改变其结构。

函数装饰器应用

import time
import functools
from typing import Callable
# 性能计时装饰器
def timing(func: Callable):
"""计算函数执行时间"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"⏱️ {func.__name__} 执行时间:{end_time - start_time:.4f}秒")
return result
return wrapper
# 缓存装饰器
def memoize(func: Callable):
"""缓存函数结果"""
cache = {}
@functools.wraps(func)
def wrapper(*args, **kwargs):
# 创建缓存键
key = str(args) + str(sorted(kwargs.items()))
if key in cache:
print(f"💾 缓存命中:{func.__name__}")
return cache[key]
result = func(*args, **kwargs)
cache[key] = result
print(f"💾 缓存设置:{func.__name__}")
return result
wrapper.cache = cache # 暴露缓存以便检查
return wrapper
# 重试装饰器
def retry(max_attempts: int = 3, delay: float = 1.0):
"""重试装饰器"""
def decorator(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for attempt in range(max_attempts):
try:
return func(*args, **kwargs)
except Exception as e:
if attempt == max_attempts - 1:
print(f"❌ {func.__name__} 重试{max_attempts}次后仍然失败")
raise e
print(f"⚠️ {func.__name__}{attempt + 1}次尝试失败,{delay}秒后重试...")
time.sleep(delay)
return wrapper
return decorator
# 使用示例
@timing
@memoize
def fibonacci(n):
"""计算斐波那契数列"""
if n <= 1:
return n
return fibonacci(n - 1) + fibonacci(n - 2)
@retry(max_attempts=3, delay=0.5)
@timing
def unreliable_api_call():
"""模拟不可靠的API调用"""
import random
if random.random() < 0.7: # 70%失败率
raise Exception("API调用失败")
return "API调用成功"
# 测试斐波那契
print("=== 斐波那契计算 ===")
result1 = fibonacci(10) # 第一次计算
result2 = fibonacci(10) # 使用缓存
# 测试重试
print("\n=== API重试测试 ===")
try:
result = unreliable_api_call()
print(f"✅ {result}")
except Exception as e:
print(f"❌ 最终失败:{e}")

类装饰器模式

# 日志装饰器类
class LoggedClass:
"""为类的所有方法添加日志"""
def __init__(self, cls):
self.cls = cls
self._wrap_methods()
def _wrap_methods(self):
"""包装所有公共方法"""
for attr_name in dir(self.cls):
attr = getattr(self.cls, attr_name)
if callable(attr) and not attr_name.startswith('_'):
setattr(self.cls, attr_name, self._log_method(attr))
def _log_method(self, method):
"""为方法添加日志"""
@functools.wraps(method)
def wrapper(*args, **kwargs):
print(f"📝 调用方法:{method.__name__}")
result = method(*args, **kwargs)
print(f"📝 方法完成:{method.__name__}")
return result
return wrapper
def __call__(self, *args, **kwargs):
"""使装饰器可以像类一样实例化"""
return self.cls(*args, **kwargs)
# 权限检查装饰器
def require_permission(permission: str):
"""权限检查装饰器"""
def decorator(cls):
original_init = cls.__init__
@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
self._required_permission = permission
original_init(self, *args, **kwargs)
cls.__init__ = new_init
# 包装所有公共方法
for attr_name in dir(cls):
attr = getattr(cls, attr_name)
if callable(attr) and not attr_name.startswith('_'):
setattr(cls, attr_name, check_permission(attr))
return cls
return decorator
def check_permission(method):
"""检查权限的方法装饰器"""
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
if not hasattr(self, '_required_permission'):
return method(self, *args, **kwargs)
# 模拟权限检查
user_permissions = getattr(self, '_user_permissions', set())
if self._required_permission not in user_permissions:
raise PermissionError(f"缺少权限:{self._required_permission}")
return method(self, *args, **kwargs)
return wrapper
# 使用装饰器
@LoggedClass
@require_permission("admin")
class AdminPanel:
def __init__(self):
self._user_permissions = set()
def set_permissions(self, permissions):
"""设置用户权限"""
self._user_permissions = set(permissions)
def create_user(self, username):
return f"创建用户:{username}"
def delete_user(self, username):
return f"删除用户:{username}"
def view_logs(self):
return "查看系统日志"
# 使用示例
print("=== 权限管理测试 ===")
admin = AdminPanel()
# 没有权限时
try:
admin.create_user("newuser")
except PermissionError as e:
print(f"❌ {e}")
# 设置权限后
admin.set_permissions(["admin", "user"])
result = admin.create_user("newuser")
print(f"✅ {result}")

策略模式:算法族的封装

策略模式定义了算法族,分别封装起来,让它们之间可以互相替换。

基础策略模式

from abc import ABC, abstractmethod
# 策略接口
class SortStrategy(ABC):
@abstractmethod
def sort(self, data):
pass
# 具体策略
class BubbleSort(SortStrategy):
def sort(self, data):
print("使用冒泡排序")
data = data.copy()
n = len(data)
for i in range(n):
for j in range(0, n - i - 1):
if data[j] > data[j + 1]:
data[j], data[j + 1] = data[j + 1], data[j]
return data
class QuickSort(SortStrategy):
def sort(self, data):
print("使用快速排序")
return self._quick_sort(data.copy())
def _quick_sort(self, data):
if len(data) <= 1:
return data
pivot = data[len(data) // 2]
left = [x for x in data if x < pivot]
middle = [x for x in data if x == pivot]
right = [x for x in data if x > pivot]
return self._quick_sort(left) + middle + self._quick_sort(right)
class MergeSort(SortStrategy):
def sort(self, data):
print("使用归并排序")
return self._merge_sort(data.copy())
def _merge_sort(self, data):
if len(data) <= 1:
return data
mid = len(data) // 2
left = self._merge_sort(data[:mid])
right = self._merge_sort(data[mid:])
return self._merge(left, right)
def _merge(self, left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result
# 上下文类
class Sorter:
def __init__(self, strategy: SortStrategy = None):
self._strategy = strategy or BubbleSort()
def set_strategy(self, strategy: SortStrategy):
"""切换策略"""
self._strategy = strategy
def sort(self, data):
"""执行排序"""
return self._strategy.sort(data)
# 使用示例
data = [64, 34, 25, 12, 22, 11, 90]
sorter = Sorter()
print("原始数据:", data)
# 使用不同的排序策略
strategies = [
("冒泡排序", BubbleSort()),
("快速排序", QuickSort()),
("归并排序", MergeSort())
]
for name, strategy in strategies:
sorter.set_strategy(strategy)
result = sorter.sort(data)
print(f"{name}结果:{result}\n")

实际应用:支付系统

from abc import ABC, abstractmethod
from enum import Enum
class PaymentStatus(Enum):
PENDING = "pending"
SUCCESS = "success"
FAILED = "failed"
# 支付策略接口
class PaymentStrategy(ABC):
@abstractmethod
def pay(self, amount: float) -> dict:
pass
@abstractmethod
def refund(self, transaction_id: str, amount: float) -> dict:
pass
# 具体支付策略
class CreditCardPayment(PaymentStrategy):
def __init__(self, card_number: str, cvv: str):
self.card_number = card_number
self.cvv = cvv
def pay(self, amount: float) -> dict:
# 模拟信用卡支付
print(f"💳 使用信用卡支付 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'transaction_id': f"cc_{hash(self.card_number + str(amount))}",
'method': 'credit_card',
'amount': amount
}
def refund(self, transaction_id: str, amount: float) -> dict:
print(f"💳 信用卡退款 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'refund_id': f"rf_{transaction_id}",
'amount': amount
}
class AlipayPayment(PaymentStrategy):
def __init__(self, account: str):
self.account = account
def pay(self, amount: float) -> dict:
print(f"💰 使用支付宝支付 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'transaction_id': f"alipay_{hash(self.account + str(amount))}",
'method': 'alipay',
'amount': amount
}
def refund(self, transaction_id: str, amount: float) -> dict:
print(f"💰 支付宝退款 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'refund_id': f"rf_{transaction_id}",
'amount': amount
}
class WeChatPayment(PaymentStrategy):
def __init__(self, wechat_id: str):
self.wechat_id = wechat_id
def pay(self, amount: float) -> dict:
print(f"💚 使用微信支付 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'transaction_id': f"wechat_{hash(self.wechat_id + str(amount))}",
'method': 'wechat',
'amount': amount
}
def refund(self, transaction_id: str, amount: float) -> dict:
print(f"💚 微信退款 ¥{amount}")
return {
'status': PaymentStatus.SUCCESS,
'refund_id': f"rf_{transaction_id}",
'amount': amount
}
# 支付上下文
class PaymentProcessor:
def __init__(self):
self._strategy = None
self._transaction_history = []
def set_payment_method(self, strategy: PaymentStrategy):
"""设置支付方式"""
self._strategy = strategy
print(f"切换支付方式:{strategy.__class__.__name__}")
def process_payment(self, amount: float) -> dict:
"""处理支付"""
if not self._strategy:
raise ValueError("请先设置支付方式")
result = self._strategy.pay(amount)
self._transaction_history.append(result)
return result
def process_refund(self, transaction_id: str, amount: float) -> dict:
"""处理退款"""
if not self._strategy:
raise ValueError("请先设置支付方式")
return self._strategy.refund(transaction_id, amount)
def get_transaction_history(self):
"""获取交易历史"""
return self._transaction_history.copy()
# 使用示例
processor = PaymentProcessor()
# 创建不同的支付方式
credit_card = CreditCardPayment("1234-5678-9012-3456", "123")
alipay = AlipayPayment("user@example.com")
wechat = WeChatPayment("wxid_123456")
print("=== 在线购物支付流程 ===")
# 模拟购物车支付
items = [
{"name": "Python编程书籍", "price": 89.90},
{"name": "无线鼠标", "price": 129.00},
{"name": "机械键盘", "price": 299.00}
]
total_amount = sum(item["price"] for item in items)
print(f"购物车总金额:¥{total_amount}")
# 用户选择支付方式
payment_methods = [
("信用卡", credit_card),
("支付宝", alipay),
("微信支付", wechat)
]
for name, method in payment_methods:
print(f"\n--- 使用{name}支付 ---")
processor.set_payment_method(method)
result = processor.process_payment(total_amount)
print(f"支付结果:{result}")
# 模拟部分退款
if result['status'] == PaymentStatus.SUCCESS:
refund_result = processor.process_refund(
result['transaction_id'],
items[0]["price"] # 退款第一件商品
)
print(f"退款结果:{refund_result}")
print(f"\n交易历史:{len(processor.get_transaction_history())}笔")

本节小结

我们学习了几个经典的设计模式:

  1. 单例模式:确保全局唯一实例,适用于配置管理、日志系统
  2. 工厂模式:动态创建对象,降低代码耦合度
  3. 观察者模式:实现发布-订阅机制,支持事件驱动编程
  4. 装饰器模式:动态添加功能,不改变原有结构
  5. 策略模式:封装算法族,使算法可以互相替换

这些模式在实际开发中非常常用,掌握它们能让你的代码更加优雅和可维护!


🚀 第7.5节:综合项目案例 - 高级ORM系统设计

在这个综合项目中,我们将设计一个小型的ORM(对象关系映射)系统,集成本章学到的所有高级特性:魔术方法、属性管理、抽象基类、混入模式和设计模式。

项目架构概览

第一步:字段类型系统

from abc import ABC, abstractmethod
from typing import Any, Optional, Dict, List, Callable
from datetime import datetime, date
import re
class Field(ABC):
"""字段基类"""
def __init__(self,
primary_key: bool = False,
nullable: bool = True,
default: Any = None,
validators: List[Callable] = None):
self.primary_key = primary_key
self.nullable = nullable
self.default = default
self.validators = validators or []
self.name = None # 将在Model元类中设置
def __set_name__(self, owner, name):
"""描述符协议:设置字段名"""
self.name = name
def __get__(self, instance, owner):
"""描述符协议:获取值"""
if instance is None:
return self
return instance.__dict__.get(self.name)
def __set__(self, instance, value):
"""描述符协议:设置值"""
# 数据验证
if value is None and not self.nullable:
raise ValueError(f"字段 {self.name} 不能为空")
if value is not None:
value = self.validate_and_convert(value)
instance.__dict__[self.name] = value
@abstractmethod
def validate_and_convert(self, value):
"""验证并转换值"""
pass
@abstractmethod
def to_db_value(self, value):
"""转换为数据库值"""
pass
@abstractmethod
def from_db_value(self, value):
"""从数据库值转换"""
pass
class StringField(Field):
"""字符串字段"""
def __init__(self, max_length: int = 255, **kwargs):
super().__init__(**kwargs)
self.max_length = max_length
def validate_and_convert(self, value):
if not isinstance(value, str):
value = str(value)
if len(value) > self.max_length:
raise ValueError(f"字符串长度不能超过 {self.max_length}")
# 执行自定义验证器
for validator in self.validators:
validator(value)
return value
def to_db_value(self, value):
return value
def from_db_value(self, value):
return value
class IntegerField(Field):
"""整数字段"""
def __init__(self, min_value: int = None, max_value: int = None, **kwargs):
super().__init__(**kwargs)
self.min_value = min_value
self.max_value = max_value
def validate_and_convert(self, value):
if not isinstance(value, int):
try:
value = int(value)
except ValueError:
raise ValueError(f"无法将 {value} 转换为整数")
if self.min_value is not None and value < self.min_value:
raise ValueError(f"值不能小于 {self.min_value}")
if self.max_value is not None and value > self.max_value:
raise ValueError(f"值不能大于 {self.max_value}")
for validator in self.validators:
validator(value)
return value
def to_db_value(self, value):
return value
def from_db_value(self, value):
return int(value) if value is not None else None
class DateTimeField(Field):
"""日期时间字段"""
def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs):
super().__init__(**kwargs)
self.auto_now = auto_now
self.auto_now_add = auto_now_add
def validate_and_convert(self, value):
if isinstance(value, str):
# 简单的日期时间解析
try:
value = datetime.fromisoformat(value)
except ValueError:
raise ValueError(f"无法解析日期时间:{value}")
if not isinstance(value, datetime):
raise ValueError("必须是datetime对象")
for validator in self.validators:
validator(value)
return value
def to_db_value(self, value):
return value.isoformat() if value else None
def from_db_value(self, value):
return datetime.fromisoformat(value) if value else None
# 字段验证器
class Validators:
@staticmethod
def email_validator(value):
"""邮箱验证器"""
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(pattern, value):
raise ValueError("无效的邮箱格式")
@staticmethod
def min_length(length):
"""最小长度验证器"""
def validator(value):
if len(value) < length:
raise ValueError(f"长度不能少于 {length}")
return validator
@staticmethod
def max_length(length):
"""最大长度验证器"""
def validator(value):
if len(value) > length:
raise ValueError(f"长度不能超过 {length}")
return validator

第二步:查询构建器

from typing import Dict, List, Any, Union
class QueryBuilder:
"""查询构建器"""
def __init__(self, model_class):
self.model_class = model_class
self._select_fields = ['*']
self._where_conditions = []
self._order_by = []
self._limit_value = None
self._offset_value = None
def select(self, *fields):
"""选择字段"""
self._select_fields = list(fields) if fields else ['*']
return self
def where(self, **conditions):
"""WHERE条件"""
for field, value in conditions.items():
self._where_conditions.append(f"{field} = ?")
return self
def order_by(self, *fields):
"""排序"""
self._order_by.extend(fields)
return self
def limit(self, count):
"""限制数量"""
self._limit_value = count
return self
def offset(self, count):
"""偏移量"""
self._offset_value = count
return self
def build_sql(self) -> str:
"""构建SQL语句"""
table_name = self.model_class._meta.table_name
# SELECT子句
select_clause = f"SELECT {', '.join(self._select_fields)}"
# FROM子句
from_clause = f"FROM {table_name}"
# WHERE子句
where_clause = ""
if self._where_conditions:
where_clause = f"WHERE {' AND '.join(self._where_conditions)}"
# ORDER BY子句
order_clause = ""
if self._order_by:
order_clause = f"ORDER BY {', '.join(self._order_by)}"
# LIMIT子句
limit_clause = ""
if self._limit_value:
limit_clause = f"LIMIT {self._limit_value}"
if self._offset_value:
limit_clause += f" OFFSET {self._offset_value}"
# 组合SQL
sql_parts = [select_clause, from_clause, where_clause, order_clause, limit_clause]
return ' '.join(part for part in sql_parts if part)
class QuerySet:
"""查询集合"""
def __init__(self, model_class, db_manager):
self.model_class = model_class
self.db_manager = db_manager
self._builder = QueryBuilder(model_class)
self._cache = None
def filter(self, **kwargs):
"""过滤"""
new_qs = self._clone()
new_qs._builder.where(**kwargs)
return new_qs
def order_by(self, *fields):
"""排序"""
new_qs = self._clone()
new_qs._builder.order_by(*fields)
return new_qs
def limit(self, count):
"""限制数量"""
new_qs = self._clone()
new_qs._builder.limit(count)
return new_qs
def all(self):
"""获取所有结果"""
if self._cache is None:
sql = self._builder.build_sql()
raw_results = self.db_manager.execute_query(sql)
self._cache = [self._create_instance(row) for row in raw_results]
return self._cache
def first(self):
"""获取第一个结果"""
results = self.limit(1).all()
return results[0] if results else None
def count(self):
"""计算数量"""
builder = QueryBuilder(self.model_class)
builder.select('COUNT(*)')
sql = builder.build_sql()
result = self.db_manager.execute_query(sql)
return result[0][0] if result else 0
def _clone(self):
"""克隆查询集"""
new_qs = QuerySet(self.model_class, self.db_manager)
new_qs._builder = QueryBuilder(self.model_class)
# 复制查询条件
new_qs._builder._select_fields = self._builder._select_fields.copy()
new_qs._builder._where_conditions = self._builder._where_conditions.copy()
new_qs._builder._order_by = self._builder._order_by.copy()
new_qs._builder._limit_value = self._builder._limit_value
new_qs._builder._offset_value = self._builder._offset_value
return new_qs
def _create_instance(self, row_data):
"""从数据库行创建模型实例"""
instance = self.model_class()
fields = self.model_class._meta.fields
for i, (field_name, field) in enumerate(fields.items()):
if i < len(row_data):
db_value = row_data[i]
python_value = field.from_db_value(db_value)
setattr(instance, field_name, python_value)
instance._state.loaded_from_db = True
return instance
def __iter__(self):
"""支持迭代"""
return iter(self.all())
def __len__(self):
"""支持len()"""
return len(self.all())
def __getitem__(self, key):
"""支持索引和切片"""
if isinstance(key, slice):
start, stop = key.start or 0, key.stop
new_qs = self._clone()
new_qs._builder.offset(start)
if stop is not None:
new_qs._builder.limit(stop - start)
return new_qs.all()
else:
return self.all()[key]

第三步:数据库管理器和连接池

import sqlite3
import threading
from queue import Queue, Empty
from contextlib import contextmanager
class ConnectionPool:
"""数据库连接池"""
def __init__(self, database_url: str, max_connections: int = 10):
self.database_url = database_url
self.max_connections = max_connections
self._pool = Queue(maxsize=max_connections)
self._current_connections = 0
self._lock = threading.Lock()
def get_connection(self):
"""获取连接"""
try:
# 尝试从池中获取连接
return self._pool.get_nowait()
except Empty:
# 池中没有连接,创建新连接
with self._lock:
if self._current_connections < self.max_connections:
conn = sqlite3.connect(self.database_url)
conn.row_factory = sqlite3.Row # 使查询结果可以按列名访问
self._current_connections += 1
return conn
else:
# 等待连接可用
return self._pool.get()
def return_connection(self, conn):
"""归还连接"""
self._pool.put(conn)
@contextmanager
def get_connection_context(self):
"""连接上下文管理器"""
conn = self.get_connection()
try:
yield conn
finally:
self.return_connection(conn)
class DatabaseManager:
"""数据库管理器(单例模式)"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
self.connection_pool = None
self._initialized = True
def configure(self, database_url: str, max_connections: int = 10):
"""配置数据库"""
self.connection_pool = ConnectionPool(database_url, max_connections)
def execute_query(self, sql: str, params: tuple = ()) -> List[tuple]:
"""执行查询"""
with self.connection_pool.get_connection_context() as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
return cursor.fetchall()
def execute_update(self, sql: str, params: tuple = ()) -> int:
"""执行更新"""
with self.connection_pool.get_connection_context() as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
conn.commit()
return cursor.rowcount
@contextmanager
def transaction(self):
"""事务管理"""
conn = self.connection_pool.get_connection()
try:
cursor = conn.cursor()
yield cursor
conn.commit()
except Exception:
conn.rollback()
raise
finally:
self.connection_pool.return_connection(conn)

第四步:事件系统和混入

class EventMixin:
"""事件混入"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._event_handlers = {}
def on(self, event_name: str, handler: Callable):
"""注册事件处理器"""
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
self._event_handlers[event_name].append(handler)
def emit(self, event_name: str, *args, **kwargs):
"""触发事件"""
if event_name in self._event_handlers:
for handler in self._event_handlers[event_name]:
handler(self, *args, **kwargs)
class TimestampMixin:
"""时间戳混入"""
created_at = DateTimeField(auto_now_add=True)
updated_at = DateTimeField(auto_now=True)
class ValidationMixin:
"""验证混入"""
def clean(self):
"""清理和验证数据"""
errors = {}
# 字段级验证
for field_name, field in self._meta.fields.items():
value = getattr(self, field_name)
try:
if value is not None:
field.validate_and_convert(value)
except ValueError as e:
errors[field_name] = str(e)
# 模型级验证
try:
self.validate()
except ValueError as e:
errors['__all__'] = str(e)
if errors:
raise ValueError(f"验证失败:{errors}")
def validate(self):
"""模型级验证(子类可重写)"""
pass
class ModelState:
"""模型状态"""
def __init__(self):
self.loaded_from_db = False
self.dirty_fields = set()
def mark_dirty(self, field_name):
"""标记字段为脏"""
self.dirty_fields.add(field_name)
def mark_clean(self):
"""标记为干净"""
self.dirty_fields.clear()

第五步:Model元类和基类

class ModelMeta(type):
"""Model元类"""
def __new__(mcs, name, bases, attrs, **kwargs):
# 不处理Model基类本身
if name == 'Model':
return super().__new__(mcs, name, bases, attrs)
# 收集字段
fields = {}
for key, value in list(attrs.items()):
if isinstance(value, Field):
fields[key] = value
# 从类属性中移除字段,将通过描述符访问
attrs.pop(key)
# 继承父类字段
for base in bases:
if hasattr(base, '_meta'):
fields.update(base._meta.fields)
# 创建Meta类
meta_attrs = {
'table_name': attrs.get('_table_name', name.lower()),
'fields': fields,
'primary_key': None
}
# 查找主键字段
for field_name, field in fields.items():
if field.primary_key:
meta_attrs['primary_key'] = field_name
break
attrs['_meta'] = type('Meta', (), meta_attrs)
# 创建类
cls = super().__new__(mcs, name, bases, attrs)
# 设置字段名称
for field_name, field in fields.items():
field.name = field_name
return cls
class Model(EventMixin, ValidationMixin, metaclass=ModelMeta):
"""ORM模型基类"""
def __init__(self, **kwargs):
super().__init__()
self._state = ModelState()
# 设置字段值
for field_name, value in kwargs.items():
if field_name in self._meta.fields:
setattr(self, field_name, value)
else:
raise AttributeError(f"模型 {self.__class__.__name__} 没有字段 {field_name}")
# 设置默认值
for field_name, field in self._meta.fields.items():
if field_name not in kwargs:
if field.default is not None:
setattr(self, field_name, field.default)
elif field.auto_now_add and isinstance(field, DateTimeField):
setattr(self, field_name, datetime.now())
def __setattr__(self, name, value):
# 追踪字段变化
if hasattr(self, '_state') and name in self._meta.fields:
self._state.mark_dirty(name)
super().__setattr__(name, value)
def save(self):
"""保存模型"""
self.emit('before_save')
self.clean() # 验证数据
db_manager = DatabaseManager()
# 处理auto_now字段
for field_name, field in self._meta.fields.items():
if isinstance(field, DateTimeField) and field.auto_now:
setattr(self, field_name, datetime.now())
if self._state.loaded_from_db:
# 更新
self._update()
else:
# 插入
self._insert()
self._state.mark_clean()
self._state.loaded_from_db = True
self.emit('after_save')
def _insert(self):
"""插入数据"""
fields = []
values = []
for field_name, field in self._meta.fields.items():
if hasattr(self, field_name):
value = getattr(self, field_name)
if value is not None:
fields.append(field_name)
values.append(field.to_db_value(value))
sql = f"INSERT INTO {self._meta.table_name} ({', '.join(fields)}) VALUES ({', '.join(['?' for _ in values])})"
db_manager = DatabaseManager()
db_manager.execute_update(sql, tuple(values))
def _update(self):
"""更新数据"""
if not self._state.dirty_fields:
return # 没有变化,不需要更新
set_clauses = []
values = []
for field_name in self._state.dirty_fields:
if field_name in self._meta.fields:
field = self._meta.fields[field_name]
value = getattr(self, field_name)
set_clauses.append(f"{field_name} = ?")
values.append(field.to_db_value(value))
primary_key = self._meta.primary_key
pk_value = getattr(self, primary_key)
sql = f"UPDATE {self._meta.table_name} SET {', '.join(set_clauses)} WHERE {primary_key} = ?"
values.append(pk_value)
db_manager = DatabaseManager()
db_manager.execute_update(sql, tuple(values))
def delete(self):
"""删除模型"""
self.emit('before_delete')
primary_key = self._meta.primary_key
pk_value = getattr(self, primary_key)
sql = f"DELETE FROM {self._meta.table_name} WHERE {primary_key} = ?"
db_manager = DatabaseManager()
db_manager.execute_update(sql, (pk_value,))
self.emit('after_delete')
@classmethod
def objects(cls):
"""获取查询集"""
db_manager = DatabaseManager()
return QuerySet(cls, db_manager)
def __str__(self):
primary_key = self._meta.primary_key
pk_value = getattr(self, primary_key, 'None')
return f"{self.__class__.__name__}(id={pk_value})"
def __repr__(self):
return self.__str__()

第六步:使用示例

# 配置数据库
db = DatabaseManager()
db.configure('example.db')
# 定义模型
class User(Model, TimestampMixin):
"""用户模型"""
id = IntegerField(primary_key=True)
username = StringField(
max_length=50,
validators=[Validators.min_length(3)]
)
email = StringField(
max_length=100,
validators=[Validators.email_validator]
)
age = IntegerField(min_value=0, max_value=150)
def validate(self):
"""模型级验证"""
if self.age and self.age < 13:
raise ValueError("用户年龄不能小于13岁")
class Post(Model, TimestampMixin):
"""文章模型"""
id = IntegerField(primary_key=True)
title = StringField(max_length=200)
content = StringField(max_length=5000)
author_id = IntegerField()
published = IntegerField(default=0) # 0=草稿, 1=已发布
# 创建表(简化版本)
def create_tables():
db_manager = DatabaseManager()
# 创建用户表
user_sql = """
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
created_at TEXT,
updated_at TEXT
)
"""
# 创建文章表
post_sql = """
CREATE TABLE IF NOT EXISTS post (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
content TEXT,
author_id INTEGER,
published INTEGER DEFAULT 0,
created_at TEXT,
updated_at TEXT
)
"""
db_manager.execute_update(user_sql)
db_manager.execute_update(post_sql)
# 使用示例
def demo():
create_tables()
print("=== ORM系统演示 ===")
# 创建用户
user = User(
username="pythondev",
email="dev@python.org",
age=25
)
# 添加事件监听
user.on('before_save', lambda model: print(f"准备保存用户:{model.username}"))
user.on('after_save', lambda model: print(f"用户保存完成:{model.username}"))
# 保存用户
user.save()
print(f"创建用户:{user}")
# 创建文章
post1 = Post(
title="Python面向对象编程",
content="这是一篇关于Python OOP的文章...",
author_id=1,
published=1
)
post1.save()
post2 = Post(
title="设计模式详解",
content="设计模式是软件开发的经典解决方案...",
author_id=1,
published=0
)
post2.save()
print(f"创建文章:{post1}")
print(f"创建文章:{post2}")
# 查询示例
print("\n=== 查询演示 ===")
# 查询所有用户
all_users = User.objects().all()
print(f"所有用户:{all_users}")
# 查询已发布的文章
published_posts = Post.objects().filter(published=1).all()
print(f"已发布文章:{published_posts}")
# 查询第一个用户
first_user = User.objects().first()
print(f"第一个用户:{first_user}")
# 更新用户
print("\n=== 更新演示 ===")
first_user.age = 26
first_user.save()
print(f"更新后的用户:{first_user}")
# 链式查询
print("\n=== 链式查询演示 ===")
recent_posts = Post.objects().filter(published=1).order_by('created_at').limit(5).all()
print(f"最近发布的文章:{recent_posts}")
if __name__ == "__main__":
demo()

项目特色总结

这个ORM系统展示了本章学习的所有高级特性:

  1. 魔术方法

    • __new__:单例模式的数据库管理器
    • __setattr__:追踪字段变化
    • __str____repr__:友好的对象表示
    • __iter____len____getitem__:QuerySet的容器协议
  2. 属性管理与描述符

    • Field类实现了完整的描述符协议
    • 自动验证和类型转换
    • __set_name__支持字段名自动设置
  3. 高级类特性

    • 元类ModelMeta自动收集字段定义
    • 类方法objects()作为查询入口
    • 静态方法用于验证器工厂
  4. 抽象基类

    • Field抽象基类定义字段接口
    • 强制子类实现必要方法
  5. 混入模式

    • EventMixin:事件系统
    • TimestampMixin:时间戳管理
    • ValidationMixin:数据验证
  6. 设计模式

    • 单例模式:DatabaseManager
    • 工厂模式:QuerySet创建
    • 观察者模式:事件系统
    • 构建器模式:QueryBuilder

🏃‍♂️ 实践练习:扩展ORM系统

练习4:为ORM系统添加以下功能:

  1. 关系字段:实现ForeignKey字段,支持关联查询
  2. 缓存系统:为查询结果添加缓存机制
  3. 迁移系统:自动生成数据库表结构变更脚本
  4. 序列化器:将模型转换为JSON格式
# TODO: 实现关系字段
class ForeignKey(Field):
def __init__(self, to_model, **kwargs):
# 实现外键关系
pass
# TODO: 实现查询缓存
class CachedQuerySet(QuerySet):
def __init__(self, *args, **kwargs):
# 添加缓存功能
pass
# TODO: 实现模型序列化
class ModelSerializer:
def serialize(self, instance):
# 转换为字典/JSON
pass

🎯 第7章总结

恭喜你完成了Python面向对象编程高级特性的学习!让我们回顾一下本章的核心内容:

知识要点回顾

  1. 魔术方法与运算符重载

    • 对象生命周期管理
    • 字符串表示和格式化
    • 运算符重载实现
    • 容器协议支持
  2. 属性管理与描述符

    • property装饰器的高级用法
    • 描述符协议的三个方法
    • 属性验证和计算属性
    • 属性访问控制机制
  3. 高级类特性

    • 类方法、静态方法的选择
    • 抽象基类的接口设计
    • 混入模式的组合使用
    • 方法解析顺序(MRO)理解
  4. 设计模式应用

    • 单例模式的多种实现
    • 工厂模式的对象创建
    • 观察者模式的事件系统
    • 装饰器模式的功能增强
    • 策略模式的算法封装
  5. 综合项目实践

    • 复杂系统的架构设计
    • 多种特性的综合运用
    • 工程化的代码组织
    • 实际问题的解决方案

编程能力提升

通过本章学习,你已经具备了:

  • 高级面向对象设计能力:能够设计复杂的类体系
  • 代码复用和模块化思维:掌握混入和组合模式
  • 设计模式应用能力:在实际项目中应用经典模式
  • 框架设计思维:理解ORM等框架的实现原理
  • 工程化开发能力:编写可维护、可扩展的代码

下一步学习建议

  1. 深入学习:研究Django、SQLAlchemy等成熟框架的源代码
  2. 实践项目:用学到的知识重构现有项目或开发新项目
  3. 扩展阅读:学习《设计模式》、《重构》等经典书籍
  4. 社区参与:在开源项目中应用和分享你的知识

面向对象编程的高级特性是Python强大功能的重要组成部分。掌握这些特性不仅能让你写出更优雅的代码,还能帮你更好地理解和使用Python生态系统中的各种工具和框架。

继续保持学习的热情,在实践中不断深化对这些概念的理解。下一章我们将学习Python的网络编程和Web开发!