70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
import json
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from src import state_manager
|
|
|
|
|
|
class TestStateManager(unittest.TestCase):
|
|
def setUp(self):
|
|
# 임시 디렉토리 생성
|
|
self.test_dir = tempfile.TemporaryDirectory()
|
|
self.state_file = os.path.join(self.test_dir.name, "bot_state.json")
|
|
|
|
# STATE_FILE 경로 모킹
|
|
self.patcher = patch("src.state_manager.STATE_FILE", self.state_file)
|
|
self.patcher.start()
|
|
|
|
def tearDown(self):
|
|
self.patcher.stop()
|
|
self.test_dir.cleanup()
|
|
|
|
def test_save_and_load_state(self):
|
|
data = {"KRW-BTC": {"max_price": 100000}}
|
|
state_manager.save_state(data)
|
|
|
|
loaded = state_manager.load_state()
|
|
self.assertEqual(loaded, data)
|
|
|
|
def test_get_and_set_value(self):
|
|
state_manager.set_value("KRW-ETH", "max_price", 200000)
|
|
state_manager.set_value("KRW-ETH", "partial_sell_done", True)
|
|
|
|
max_price = state_manager.get_value("KRW-ETH", "max_price")
|
|
partial_sell = state_manager.get_value("KRW-ETH", "partial_sell_done")
|
|
|
|
self.assertEqual(max_price, 200000)
|
|
self.assertTrue(partial_sell)
|
|
|
|
def test_update_max_price_state(self):
|
|
symbol = "KRW-XRP"
|
|
|
|
# 1. 초기값 설정
|
|
state_manager.set_value(symbol, "max_price", 100)
|
|
|
|
# 2. 더 낮은 가격 업데이트 (무시되어야 함)
|
|
state_manager.update_max_price_state(symbol, 90)
|
|
self.assertEqual(state_manager.get_value(symbol, "max_price"), 100)
|
|
|
|
# 3. 더 높은 가격 업데이트 (반영되어야 함)
|
|
state_manager.update_max_price_state(symbol, 110)
|
|
self.assertEqual(state_manager.get_value(symbol, "max_price"), 110)
|
|
|
|
def test_persistence_across_instances(self):
|
|
# 파일에 저장
|
|
state_manager.set_value("KRW-SOL", "test_key", "test_value")
|
|
|
|
# 파일이 실제로 존재하는지 확인
|
|
self.assertTrue(os.path.exists(self.state_file))
|
|
|
|
# 파일을 직접 읽어서 확인
|
|
with open(self.state_file, encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
self.assertEqual(data["KRW-SOL"]["test_key"], "test_value")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|