Unit Testing: Using Mockito


Mockito is a mocking framework for tests. Its main job is to simulate the behavior of classes and objects in unit tests.

The method under test is:

public void recallMessage(RecallMessageDTO recallMessageDTO) {
    log.info("[融云][单条消息撤回]请求参数: [recallMessageDTO={}]", recallMessageDTO);
    RecallMessage recallMessage = new RecallMessage()
            .setSenderId(recallMessageDTO.getFromUserId())
            .setTargetId(recallMessageDTO.getTargetId())
            .setuId(recallMessageDTO.getMessageUID())
            .setSentTime(recallMessageDTO.getSentTime())
            .setIsAdmin(recallMessageDTO.getIsAdmin())
            .setIsDelete(recallMessageDTO.getIsDelete())
            .setExtra(recallMessageDTO.getExtra())
            .setDisablePush(recallMessageDTO.getDisablePush());
 
    Result result = null;
 
    try{
        result = rongCloud.message.msgPrivate.recall(recallMessage);
        log.info("[融云][单条消息撤回]请求结果: [result={}]", result);
 
        if (Objects.isNull(result)) {
            throw new BusinessException("[融云][单条消息撤回]请求异常");
        }
 
        if (!RongCloudCodeEnum.SUCCESS.getCode().equals(result.getCode())) {
            throw new BusinessException(String.format("融云聊天室单条消息撤回失败: %s", result.getErrorMessage()));
        }
 
    } catch (Exception e) {
        log.error("[融云][单条消息撤回]请求失败:[recallMessageDTO={}]", recallMessageDTO, e);
        throw new BusinessException("[融云][单条消息撤回]异常", e);
    }
}

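I actually wrote this method, and looking at it now it is pretty bad. The try-catch wipes out the crime scene: the BusinessException thrown inside the try block is caught by catch (Exception e) and rewrapped, so every failure surfaces with the same generic message. With some reluctance, let's use it as the unit-testing example anyway.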

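The recallMessage method has the following cases to cover:

  1. result is null
  2. result is returned but its code is not success
  3. the rongCloud call throws an exception
  4. the method completes normally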

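The call chain is rongCloud.message.msgPrivate.recall(...), so besides mocking RongCloud itself we also need to mock Message and Private and assign them to the corresponding public fields of rongCloud.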

The corresponding unit test code is:

public class RongCloudServiceTest {
     
    @InjectMocks
    private RongCloudService rongCloudService;
 
    @Mock
    private RongCloud rongCloud;
 
    @Rule
    public ExpectedException thrown= ExpectedException.none();
 
    @Before
    public void setUp() {
        MockitoAnnotations.initMocks(this);
 
        Private privateMsg = Mockito.mock(Private.class);
        Message message = Mockito.mock(Message.class);
 
        message.msgPrivate = privateMsg;
        rongCloud.message = message;
    }
 
    @Test
    public void recallMessage() throws Exception {
        Mockito.when(rongCloud.message.msgPrivate.recall(Mockito.any(RecallMessage.class)))
                .thenReturn(new ResponseResult(RongCloudCodeEnum.SUCCESS.getCode(), RongCloudCodeEnum.SUCCESS.getValue()));
 
        rongCloudService.recallMessage(new RecallMessageDTO());
    }
 
    @Test
    public void recallMessageNull() throws Exception {
        thrown.expect(BusinessException.class);
        thrown.expectMessage("[融云][单条消息撤回]异常");
 
        Mockito.when(rongCloud.message.msgPrivate.recall(Mockito.any(RecallMessage.class))).thenReturn(null);
 
        rongCloudService.recallMessage(new RecallMessageDTO());
    }
 
    @Test
    public void recallMessageError() throws Exception {
        thrown.expect(BusinessException.class);
        thrown.expectMessage("[融云][单条消息撤回]异常");
 
        Mockito.when(rongCloud.message.msgPrivate.recall(Mockito.any(RecallMessage.class)))
                .thenReturn(new ResponseResult(RongCloudCodeEnum.SERVER_ERROR.getCode(), RongCloudCodeEnum.SERVER_ERROR.getValue()));
 
        rongCloudService.recallMessage(new RecallMessageDTO());
    }
 
    @Test
    public void recallMessageException() throws Exception {
        thrown.expect(BusinessException.class);
        thrown.expectMessage("[融云][单条消息撤回]异常");
 
        Mockito.when(rongCloud.message.msgPrivate.recall(Mockito.any(RecallMessage.class))).thenThrow(new ConnectException("连接超时"));
 
        rongCloudService.recallMessage(new RecallMessageDTO());
    }
}
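All three failure tests in RongCloudServiceTest expect the same message, "[融云][单条消息撤回]异常", because the BusinessException thrown inside the try block is itself caught by catch (Exception e) and rewrapped. The plain-Java sketch below reproduces that effect and shows one possible fix: rethrow BusinessException untouched before the generic catch. The nested BusinessException class is a hypothetical stand-in for the project's own class (assumed to be unchecked), and callRemote simulates the null-result case.

```java
class WrapDemo {
    // Hypothetical stand-in for the project's BusinessException.
    static class BusinessException extends RuntimeException {
        BusinessException(String message) { super(message); }
        BusinessException(String message, Throwable cause) { super(message, cause); }
    }

    // Simulates the remote call returning null.
    static Object callRemote() {
        return null;
    }

    // Same shape as recallMessage: the specific exception is swallowed by the generic catch.
    static void wrapping() {
        try {
            if (callRemote() == null) {
                throw new BusinessException("request returned null");
            }
        } catch (Exception e) {
            throw new BusinessException("generic failure", e);
        }
    }

    // Possible fix: let BusinessException pass through and wrap only unexpected exceptions.
    static void passThrough() {
        try {
            if (callRemote() == null) {
                throw new BusinessException("request returned null");
            }
        } catch (BusinessException e) {
            throw e;
        } catch (Exception e) {
            throw new BusinessException("generic failure", e);
        }
    }

    // Runs the action and returns the message of the BusinessException it throws, or null.
    static String messageOf(Runnable action) {
        try {
            action.run();
            return null;
        } catch (BusinessException e) {
            return e.getMessage();
        }
    }

    public static void main(String[] args) {
        System.out.println(messageOf(WrapDemo::wrapping));    // generic failure
        System.out.println(messageOf(WrapDemo::passThrough)); // request returned null
    }
}
```

With a pass-through like this, the null-result and non-success tests could assert on their own specific messages instead of the shared generic one.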

How to unit test a private method

The method under test is:

private Map<String, Set<String>> getChannelSession(List<Integer> to) {
		Map<String, Set<String>> channelSession = new HashMap<>(4);

		to.stream().filter(Objects::nonNull).forEach(receiver -> {
			Set<String> sessionSet = redisTemplate.opsForSet().members(PushConstant.REDIS_SERVICE_USERID + receiver);
			sessionSet.stream().filter(StringUtils::isNotEmpty).forEach(session -> {
				String channel = (String) redisTemplate.opsForHash().get(REDIS_SERVICE_SESSSION + session, "channel");
				if (StringUtils.isNotEmpty(channel)) {
					if (channelSession.containsKey(channel)) {
						Set<String> pushSessionSet = channelSession.get(channel);
						pushSessionSet.add(session);
					}

					else {
						Set<String> pushSessionSet =  new HashSet<>();
						pushSessionSet.add(session);
						channelSession.put(channel, pushSessionSet);
					}
				}
			});
		});

		return channelSession;
	}

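Two Redis data structures are involved, as can be read from the code: a Set stored at PushConstant.REDIS_SERVICE_USERID + userId that holds the user's session IDs, and a Hash stored at REDIS_SERVICE_SESSSION + sessionId whose "channel" field holds the channel that session is connected to.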
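To make the data shapes concrete, the sketch below replaces Redis with two in-memory maps and groups sessions by channel the way getChannelSession does: one map plays the per-user Set of session IDs, the other plays the "channel" field of each session's Hash. The names ChannelGrouping, userSessions, sessionChannel, and the short session IDs are all made up for illustration. It also uses computeIfAbsent in place of the containsKey/else branch, which is a stylistic alternative rather than the original code.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class ChannelGrouping {
    // userSessions stands in for the per-user Redis Set of session IDs;
    // sessionChannel stands in for the "channel" field of each session's Hash.
    static Map<String, Set<String>> group(List<Integer> to,
                                          Map<Integer, Set<String>> userSessions,
                                          Map<String, String> sessionChannel) {
        Map<String, Set<String>> channelSession = new HashMap<>();
        for (Integer receiver : to) {
            if (receiver == null) {
                continue;
            }
            for (String session : userSessions.getOrDefault(receiver, Collections.emptySet())) {
                String channel = sessionChannel.get(session);
                if (channel == null || channel.isEmpty()) {
                    continue; // a session without a known channel is skipped
                }
                channelSession.computeIfAbsent(channel, key -> new HashSet<>()).add(session);
            }
        }
        return channelSession;
    }

    // Sample data: two users, two sessions each, spread over two channels.
    static Map<String, Set<String>> demo() {
        Map<Integer, Set<String>> userSessions = new HashMap<>();
        userSessions.put(1, new HashSet<>(Arrays.asList("s1", "s2")));
        userSessions.put(2, new HashSet<>(Arrays.asList("s3", "s4")));
        Map<String, String> sessionChannel = new HashMap<>();
        sessionChannel.put("s1", "122.168.0.130:30002");
        sessionChannel.put("s2", "122.168.0.131:30002");
        sessionChannel.put("s3", "122.168.0.130:30002");
        sessionChannel.put("s4", "122.168.0.131:30002");
        return group(Arrays.asList(1, 2), userSessions, sessionChannel);
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

The result maps each channel to the set of sessions connected to it, which is the shape the assertions on getChannelSession need to check.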

The test class I wrote:

@RunWith(SpringJUnit4ClassRunner.class)
@SpringBootTest
public class MessageServiceImplTest {

    @InjectMocks
    private MessageServiceImpl messageService;

    @Mock
    private RedisTemplate redisTemplate;

    @Before
    public void setUp() {

        MockitoAnnotations.initMocks(this);
    }

    @Test
    public void testConvertPackageUsageNoData() {
        Mockito.when(redisTemplate.opsForSet()).thenReturn(Mockito.mock(SetOperations.class));
        Mockito.when(redisTemplate.opsForHash()).thenReturn(Mockito.mock(HashOperations.class));


        Set<String> set1 = new HashSet<>();
        set1.add("4fd23740-4365-46d3-8ce7-f2dc9abd0a2b");
        set1.add("470bec5c-767f-4fde-aea7-4a43449353a4");

        Set<String> set2 = new HashSet<>();
        set2.add("1cefae25-29bc-4250-890b-a2eb7e36fd11");
        set2.add("a7174444-5327-4f6e-8205-8083c3030747");

        Mockito.when(redisTemplate.opsForSet().members(PushConstant.REDIS_SERVICE_USERID + 1)).thenReturn(set1);
        Mockito.when(redisTemplate.opsForSet().members(PushConstant.REDIS_SERVICE_USERID + 2)).thenReturn(set2);
        Mockito.when(redisTemplate.opsForHash().get(REDIS_SERVICE_SESSSION + "4fd23740-4365-46d3-8ce7-f2dc9abd0a2b", "channel")).thenReturn("122.168.0.130:30002");
        Mockito.when(redisTemplate.opsForHash().get(REDIS_SERVICE_SESSSION + "470bec5c-767f-4fde-aea7-4a43449353a4", "channel")).thenReturn("122.168.0.131:30002");
        Mockito.when(redisTemplate.opsForHash().get(REDIS_SERVICE_SESSSION + "1cefae25-29bc-4250-890b-a2eb7e36fd11", "channel")).thenReturn("122.168.0.130:30002");
        Mockito.when(redisTemplate.opsForHash().get(REDIS_SERVICE_SESSSION + "a7174444-5327-4f6e-8205-8083c3030747", "channel")).thenReturn("122.168.0.131:30002");

        List<Integer> toUser = new ArrayList<>(2);
        toUser.add(1);
        toUser.add(2);
        Map<String, Set<String>> map = ReflectionTestUtils.invokeMethod(messageService, "getChannelSession", toUser);
        Assert.assertNotNull(map);
        Assert.assertFalse(map.isEmpty());
        Assert.assertEquals(2, map.size());
        Assert.assertTrue(map.containsKey("122.168.0.130:30002"));
        Assert.assertTrue(map.containsKey("122.168.0.131:30002"));

        Assert.assertEquals(2, map.get("122.168.0.130:30002").size());
        Assert.assertEquals(2, map.get("122.168.0.131:30002").size());
    }

}
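MessageServiceImplTest reaches the private method through Spring's ReflectionTestUtils.invokeMethod. Where spring-test is not available, roughly the same thing can be done with plain java.lang.reflect. The sketch below is a minimal illustration using a hypothetical Greeter class and a made-up invokePrivate helper; it is not how Spring implements it.

```java
import java.lang.reflect.Method;

class PrivateCallDemo {
    // Hypothetical class with a private method we want to exercise from a test.
    static class Greeter {
        private String greet(String name) {
            return "hello, " + name;
        }
    }

    // Looks up a declared method by name and parameter types, then invokes it.
    static Object invokePrivate(Object target, String name, Class<?>[] types, Object... args) {
        try {
            Method method = target.getClass().getDeclaredMethod(name, types);
            method.setAccessible(true); // bypass the private modifier for the test
            return method.invoke(target, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("could not invoke " + name, e);
        }
    }

    public static void main(String[] args) {
        Object result = invokePrivate(new Greeter(), "greet", new Class<?>[] {String.class}, "mockito");
        System.out.println(result); // hello, mockito
    }
}
```

Unlike this helper, ReflectionTestUtils.invokeMethod works out the parameter types from the arguments, which is why the test only passes the method name and toUser.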

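One problem with the MessageServiceImplTest class above: if there are many methods to test, every test method has to repeat the following stubbing:

Mockito.when(redisTemplate.opsForSet()).thenReturn(Mockito.mock(SetOperations.class));
Mockito.when(redisTemplate.opsForHash()).thenReturn(Mockito.mock(HashOperations.class));

A better approach is to mock the whole RedisTemplate class once: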

@Component
public class RedisTemplateSpy implements BeanPostProcessor {
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if ("redisTemplate".equals(beanName)) {

            RedisTemplate redisTemplate = Mockito.mock(RedisTemplate.class);
            ValueOperations valueOperations = Mockito.mock(ValueOperations.class);
            SetOperations setOperations = Mockito.mock(SetOperations.class);
            HashOperations hashOperations = Mockito.mock(HashOperations.class);
            ListOperations listOperations = Mockito.mock(ListOperations.class);
            ZSetOperations zSetOperations = Mockito.mock(ZSetOperations.class);
            Mockito.when(redisTemplate.opsForSet()).thenReturn(setOperations);
            Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
            Mockito.when(redisTemplate.opsForHash()).thenReturn(hashOperations);
            Mockito.when(redisTemplate.opsForList()).thenReturn(listOperations);
            Mockito.when(redisTemplate.opsForZSet()).thenReturn(zSetOperations);

            RedisOperations redisOperations = Mockito.mock(RedisOperations.class);
            RedisConnection redisConnection = Mockito.mock(RedisConnection.class);
            RedisConnectionFactory redisConnectionFactory = Mockito.mock(RedisConnectionFactory.class);
            Mockito.when(redisTemplate.getConnectionFactory()).thenReturn(redisConnectionFactory);
            Mockito.when(valueOperations.getOperations()).thenReturn(redisOperations);
            Mockito.when(redisTemplate.getConnectionFactory().getConnection()).thenReturn(redisConnection);

            return redisTemplate;
        }
        return bean;
    }
}

Wherever RedisTemplate is needed, we can simply inject it as usual, because the BeanPostProcessor above has already replaced the redisTemplate bean with a mock.

@Autowired
private RedisTemplate redisTemplate;

Verifying the return value

We will unit test the endpoint http://localhost:8080/manager/user/55424:

@RequestMapping(value = "/user/{userId}", method = { RequestMethod.GET })
public Response<Map<String, List<String>>> getUser(@PathVariable Integer userId) {
    log.info("[推送服务GATEWAY] 获取推送服务上单个WS用户 请求参数 [userId {}]", userId);

    Map<String, List<String>> sessionInfo = managerService.getUserSessions(userId);
    return Response.ok(sessionInfo);
}

The response is:

{
    "version": 0,
    "status": 0,
    "errMsg": "ok",
    "errorMsg": "ok",
    "ts": 1626188639683,
    "data": {
        "55424": [
            "7c921ff9-9a76-4ad8-92e4-02de140b40ad"
            ]
    }
}

The corresponding unit test class is:

@WebMvcTest({ ManagerController.class})
public class ManagerControllerTest {

    private MockMvc mockMvc;

    @InjectMocks
    private ManagerController controller;

    @Mock
    private ManagerServiceImpl service;

    @Before
    public void before() {
        MockitoAnnotations.initMocks(this);
        this.mockMvc = MockMvcBuilders
                .standaloneSetup(controller)
                .build();
    }

    @Test
    public void insertSupplierBrandStrategyValid() throws Exception {

        Map<String, List<String>> map = new HashMap<>();
        List<String> sessionList = new ArrayList<>(2);
        String session1 = UUID.randomUUID().toString();
        String session2 = UUID.randomUUID().toString();
        sessionList.add(session1);
        sessionList.add(session2);
        map.put("1", sessionList);
        when(service.getUserSessions(1)).thenReturn(map);
        MockHttpServletRequestBuilder requestBuilder = MockMvcRequestBuilders
                .get("/manager/user/1")
                .accept(MediaType.APPLICATION_JSON)
                .contentType(MediaType.APPLICATION_JSON);
        MvcResult mvcResult = mockMvc.perform(requestBuilder)
                .andDo(print()) // print the details of the request and response
                .andExpect(status().isOk())
                .andExpect(MockMvcResultMatchers.jsonPath("$.data.length()").value(1))
                .andExpect(MockMvcResultMatchers.jsonPath("$.data", Matchers.hasKey("1")))
                .andExpect(MockMvcResultMatchers.jsonPath("$.data", Matchers.hasEntry("1",sessionList )))
                .andReturn();

        System.out.println(mvcResult.getResponse().getContentAsString());
    }

}

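thenReturn, doReturn, and similar methods support chaining, which lets you specify the behavior for each successive call of a method.

// The first call returns 100, the second call returns 200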
when(exampleService.add(1, 2)).thenReturn(100).thenReturn(200);
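As a rough mental model of consecutive stubbing: answers are handed out in order, and once they run out the last one keeps being returned. The plain-Java sketch below models that with a queue. It illustrates the semantics only and is not Mockito's implementation; ConsecutiveAnswers is a made-up name.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

class ConsecutiveAnswers<T> {
    private final Deque<T> answers = new ArrayDeque<>();

    @SafeVarargs
    ConsecutiveAnswers(T first, T... rest) {
        answers.add(first);
        answers.addAll(Arrays.asList(rest));
    }

    // Returns the next answer; the last one is never removed, so it repeats.
    T next() {
        return answers.size() > 1 ? answers.poll() : answers.peek();
    }

    public static void main(String[] args) {
        ConsecutiveAnswers<Integer> add = new ConsecutiveAnswers<>(100, 200);
        System.out.println(add.next()); // 100
        System.out.println(add.next()); // 200
        System.out.println(add.next()); // 200
    }
}
```

In Mockito itself, the same chained stubbing can also be written with the varargs form thenReturn(100, 200).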

Author: WangQingLei
Copyright notice: Unless otherwise stated, all articles on this blog are licensed under CC BY 4.0. Please credit WangQingLei when reposting.