Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ requests==2.32.4
tqdm==4.66.3
absl-py==0.12.0
browser_cookie3==0.20.1
aiohttp
aiohttp
pydantic
aiofiles
53 changes: 53 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
from pydantic import ValidationError
from weibo_spider.config import SpiderConfig

class TestSpiderConfig(unittest.TestCase):
def setUp(self):
self.base_config = {
'user_id_list': ['123456'],
'cookie': 'test_cookie',
'filter': 0,
'since_date': '2023-01-01',
'end_date': 'now',
'write_mode': ['csv']
}

def test_valid_config(self):
config = SpiderConfig(**self.base_config)
self.assertEqual(config.user_id_list, ['123456'])
self.assertEqual(config.filter, 0)

def test_invalid_filter(self):
config = self.base_config.copy()
config['filter'] = 2 # Invalid
with self.assertRaises(ValidationError):
SpiderConfig(**config)

def test_invalid_date(self):
config = self.base_config.copy()
config['since_date'] = 'invalid-date'
with self.assertRaises(ValidationError):
SpiderConfig(**config)

def test_invalid_write_mode(self):
config = self.base_config.copy()
config['write_mode'] = ['invalid_mode']
with self.assertRaises(ValidationError):
SpiderConfig(**config)

def test_wait_ranges(self):
config = self.base_config.copy()
# Test start > end
config['random_wait_pages'] = [5, 1]
with self.assertRaises(ValidationError) as cm:
SpiderConfig(**config)
self.assertIn('等待范围起始值不能大于结束值', str(cm.exception))

# Test negative values
config['random_wait_pages'] = [-1, 5]
with self.assertRaises(ValidationError):
SpiderConfig(**config)

if __name__ == '__main__':
unittest.main()
23 changes: 23 additions & 0 deletions tests/test_datetime_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest
from datetime import datetime
from weibo_spider.datetime_util import str_to_time, is_valid_date

class TestDatetimeUtil(unittest.TestCase):
def test_str_to_time(self):
# Test YYYY-MM-DD
dt1 = str_to_time('2023-01-01')
self.assertEqual(dt1, datetime(2023, 1, 1))

# Test YYYY-MM-DD HH:MM
dt2 = str_to_time('2023-01-01 12:30')
self.assertEqual(dt2, datetime(2023, 1, 1, 12, 30))

# Test invalid format (should raise ValueError)
with self.assertRaises(ValueError):
str_to_time('invalid-date')

def test_is_valid_date(self):
self.assertTrue(is_valid_date('2023-01-01'))
self.assertTrue(is_valid_date('2023-01-01 12:30'))
self.assertFalse(is_valid_date('invalid-date'))
self.assertFalse(is_valid_date('2023/01/01')) # Wrong separator
28 changes: 28 additions & 0 deletions tests/test_spider_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
from unittest.mock import MagicMock, patch
from weibo_spider.spider import Spider
from weibo_spider.config import SpiderConfig

class TestSpiderInit(unittest.TestCase):
@patch('weibo_spider.spider.FLAGS')
def test_init(self, mock_flags):
# Mock flags
mock_flags.user_id_list = None
mock_flags.u = None
mock_flags.output_dir = None

config_dict = {
'user_id_list': ['123'],
'filter': 1,
'since_date': '2023-01-01',
'write_mode': ['csv'],
'cookie': 'cookie'
}
config = SpiderConfig(**config_dict)
spider = Spider(config)
self.assertEqual(spider.filter, 1)
self.assertEqual(spider.since_date, '2023-01-01')
self.assertIsInstance(spider.config, SpiderConfig)

if __name__ == '__main__':
unittest.main()
85 changes: 85 additions & 0 deletions tests/test_writers_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import unittest
import asyncio
import os
import shutil
import sys
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime

# Mock aiofiles before importing modules that use it
sys.modules['aiofiles'] = MagicMock()

from weibo_spider.writer.txt_writer import TxtWriter
from weibo_spider.writer.csv_writer import CsvWriter
from weibo_spider.writer.json_writer import JsonWriter
from weibo_spider.user import User
from weibo_spider.weibo import Weibo

class TestWritersAsync(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.user = User(
id='123456',
nickname='test_user',
weibo_num=100,
following=50,
followers=200
)
self.weibo = Weibo(
id='w1',
user_id='123456',
content='Test Weibo Content',
publish_time='2023-01-01',
up_num=10,
retweet_num=5,
comment_num=2
)
self.weibos = [self.weibo]

@patch('weibo_spider.writer.txt_writer.aiofiles.open')
async def test_txt_writer(self, mock_aio_open):
file_path = 'test.txt'
writer = TxtWriter(file_path, filter=0)

mock_f = AsyncMock()
mock_aio_open.return_value.__aenter__.return_value = mock_f

await writer.write_user(self.user)
mock_f.write.assert_called()

await writer.write_weibo(self.weibos)
mock_f.write.assert_called()

@patch('weibo_spider.writer.csv_writer.aiofiles.open')
async def test_csv_writer(self, mock_aio_open):
file_path = 'test.csv'
writer = CsvWriter(file_path, filter=0)

mock_f = AsyncMock()
mock_aio_open.return_value.__aenter__.return_value = mock_f

await writer.write_weibo(self.weibos)
mock_f.write.assert_called()
# Verify content was written
args, _ = mock_f.write.call_args
self.assertIn('w1', args[0])
self.assertIn('Test Weibo Content', args[0])

@patch('weibo_spider.writer.json_writer.aiofiles.open')
async def test_json_writer(self, mock_aio_open):
file_path = 'test.json'
writer = JsonWriter(file_path)

mock_f = AsyncMock()
mock_aio_open.return_value.__aenter__.return_value = mock_f
# Mock read to return empty valid json
mock_f.read.return_value = '{}'

await writer.write_user(self.user)
await writer.write_weibo(self.weibos)

mock_f.write.assert_called()
args, _ = mock_f.write.call_args
self.assertIn('test_user', args[0])

if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions weibo_spider/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import sys
from pathlib import Path

from absl import app
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
sys.path.append(str(Path.cwd().parent.absolute()))
from weibo_spider.spider import main

app.run(main)
127 changes: 127 additions & 0 deletions weibo_spider/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import List, Optional, Union, Dict, Any
from pydantic import BaseModel, Field, field_validator, ConfigDict
from pathlib import Path
from datetime import datetime
from .datetime_util import is_valid_date


class SpiderConfig(BaseModel):
"""微博爬虫配置类"""
model_config = ConfigDict(arbitrary_types_allowed=True)

user_id_list: Union[List[Union[str, Dict[str, str]]], str] = Field(
description="要爬取的微博用户ID列表,可以是ID列表,也可以是包含ID的字典列表,或txt文件路径。"
)
cookie: str = Field(description="微博的Cookie,用于身份验证。")
filter: int = Field(default=0, description="过滤类型,0表示抓取全部微博,1表示只抓取原创微博。")
since_date: Union[int, str] = Field(
default=0,
description="抓取起始时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或整数(表示从今天起的前n天)。"
)
end_date: str = Field(
default="now",
description="抓取结束时间,形式为YYYY-MM-DD或YYYY-MM-DD HH:MM,或'now'表示当前时间。"
)
random_wait_pages: List[int] = Field(
default_factory=lambda: [1, 5],
description="随机等待页数范围 [min, max],每爬取多少页暂停一次。"
)
random_wait_seconds: List[int] = Field(
default_factory=lambda: [6, 10],
description="随机等待时间范围 [min, max](秒),每次暂停等待的时长。"
)
global_wait: List[List[int]] = Field(
default_factory=lambda: [[1000, 3600]],
description="全局等待配置 [[页数, 秒数], ...],例如每爬取1000页等待3600秒。"
)
write_mode: List[str] = Field(
default_factory=lambda: ["csv"],
description="结果保存类型列表,可包含 'txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka', 'post'。"
)
pic_download: int = Field(default=0, description="是否下载微博图片,0不下载,1下载。")
video_download: int = Field(default=0, description="是否下载微博视频,0不下载,1下载。")
file_download_timeout: List[int] = Field(
default_factory=lambda: [5, 5, 10],
description="文件下载超时设置 [重试次数, 连接超时, 读取超时]。"
)
result_dir_name: int = Field(default=0, description="结果目录命名方式,0使用用户昵称,1使用用户ID。")
mysql_config: Optional[Dict[str, Any]] = Field(default=None, description="MySQL数据库连接配置字典。")
sqlite_config: Optional[str] = Field(default=None, description="SQLite数据库连接路径。")
kafka_config: Optional[Dict[str, Any]] = Field(default=None, description="Kafka配置字典。")
mongo_config: Optional[Dict[str, Any]] = Field(default=None, description="MongoDB配置字典。")
post_config: Optional[Dict[str, Any]] = Field(default=None, description="POST请求配置字典(用于数据推送)。")

@field_validator('filter', 'pic_download', 'video_download')
@classmethod
def check_binary(cls, v: int) -> int:
if v not in (0, 1):
raise ValueError(f'值应为0或1, 得到: {v}')
return v

@field_validator('since_date')
@classmethod
def check_since_date(cls, v: Union[int, str]) -> Union[int, str]:
if isinstance(v, int):
return v
if not is_valid_date(str(v)):
raise ValueError(f'since_date值应为yyyy-mm-dd形式或整数, 得到: {v}')
return v

@field_validator('end_date')

@classmethod

def check_end_date(cls, v: str) -> str:

if v == 'now' or is_valid_date(v):

return v

raise ValueError(f'end_date值应为yyyy-mm-dd形式或"now", 得到: {v}')

@field_validator('random_wait_pages', 'random_wait_seconds')
@classmethod
def check_wait_range(cls, v: List[int]) -> List[int]:
if not isinstance(v, list) or len(v) != 2:
raise ValueError('参数值应为包含两个整数的list类型')
if not all(isinstance(i, int) for i in v):
raise ValueError('列表中的值应为整数类型')
if min(v) < 1:
raise ValueError('列表中的值应大于0')
if v[0] > v[1]:
raise ValueError('等待范围起始值不能大于结束值')
return v

@field_validator('global_wait')
@classmethod
def check_global_wait(cls, v: List[List[int]]) -> List[List[int]]:
for g in v:
if not isinstance(g, list) or len(g) != 2:
raise ValueError('参数内的值应为长度为2的list类型')
if not all(isinstance(i, int) and i >= 1 for i in g):
raise ValueError('列表中的值应为大于0的整数')
return v

@field_validator('write_mode')
@classmethod
def check_write_mode(cls, v: List[str]) -> List[str]:
valid_modes = {'txt', 'csv', 'json', 'mongo', 'mysql', 'sqlite', 'kafka', 'post'}
for mode in v:
if mode not in valid_modes:
raise ValueError(f'{mode}为无效模式')
return v

@field_validator('user_id_list')
@classmethod
def check_user_id_list(cls, v: Union[List[Union[str, Dict[str, str]]], str]) -> Union[List[Union[str, Dict[str, str]]], str]:
if isinstance(v, list):
return v
if not v.endswith('.txt'):
raise ValueError('user_id_list值应为list类型或txt文件路径')

path = Path(v)
if not path.is_absolute():
path = Path.cwd() / v
if not path.is_file():
raise ValueError(f'不存在{path}文件')
return str(path)
Loading
Loading